/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {
namespace grappler {

class MklRemapperTest : public GrapplerTest {
 public:
  const string kAddNOp = "AddN";
  const string kAddOp = "Add";
  const string kAddV2Op = "AddV2";

 protected:
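  // Builds Conv2D -> BiasAdd -> {AddN|AddV2|Add} (optionally followed by an
  // activation) and checks whether the remapper fuses the pattern into a
  // single _FusedConv2D node.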
  void FuseConv2DWithBiasAndAddNOrAdd(const string& data_format,
                                      const string& activation, string add_op,
                                      bool add_with_bcast) {
    using ::tensorflow::ops::Placeholder;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto input_shape = (data_format == "NHWC")
                           ? ops::Placeholder::Shape({8, 32, 32, 3})
                           : ops::Placeholder::Shape({8, 3, 32, 32});
    auto input_shape_addn = ops::Placeholder::Shape({});
    if (data_format == "NHWC") {
      if (add_with_bcast)
        input_shape_addn = ops::Placeholder::Shape({128});
      else
        input_shape_addn = ops::Placeholder::Shape({8, 32, 32, 128});
    } else {
      if (add_with_bcast)
        input_shape_addn = ops::Placeholder::Shape({32});
      else
        input_shape_addn = ops::Placeholder::Shape({8, 128, 32, 32});
    }
    auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128});
    auto bias_shape = ops::Placeholder::Shape({128});

    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
    auto input_addn =
        Placeholder(s.WithOpName("input_addn"), DT_FLOAT, input_shape_addn);
    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);

    std::vector<int> strides = {1, 1, 1, 1};
    auto conv =
        ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME",
                    ops::Conv2D::Attrs().DataFormat(data_format));
    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias,
                                 ops::BiasAdd::Attrs().DataFormat(data_format));

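    // Appends the requested activation (if any) after the given add op and
    // wires the result into an Identity "fetch" node.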
    auto addfetch = [&](::tensorflow::Input addop) {
      auto activate = s.WithOpName("activation");
      auto fetch = s.WithOpName("fetch");
      if (activation == "Relu") {
        ops::Identity(fetch, ops::Relu(activate, addop));
      } else if (activation == "Relu6") {
        ops::Identity(fetch, ops::Relu6(activate, addop));
      } else if (activation == "Elu") {
        ops::Identity(fetch, ops::Elu(activate, addop));
      } else if (activation == "LeakyRelu") {
        ops::Identity(fetch, ops::internal::LeakyRelu(activate, addop));
      } else {
        DCHECK(activation == "None");
        ops::Identity(fetch, addop);
      }
    };

    if (add_op == kAddNOp) {
      auto addn = ops::AddN(s.WithOpName(add_op),
                            std::initializer_list<Input>{input_addn, bias_add});
      addfetch(addn);
    } else if (add_op == kAddV2Op) {
      auto add = ops::AddV2(s.WithOpName(add_op), input_addn, bias_add);
      addfetch(add);
    } else {
      auto add = ops::Add(s.WithOpName(add_op), input_addn, bias_add);
      addfetch(add);
    }
    auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape.shape_.dim_sizes()));
    auto input_addn_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape_addn.shape_.dim_sizes()));
    auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(filter_shape.shape_.dim_sizes()));
    auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(bias_shape.shape_.dim_sizes()));

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"input", input_tensor},
                 {"filter", filter_tensor},
                 {"bias", bias_tensor},
                 {"input_addn", input_addn_tensor}};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }
    // Set the rewriter config to AGGRESSIVE so that Placeholder shapes can be
    // used to test that an Add whose inputs have the same shape gets fused
    // with Conv2D. The AGGRESSIVE setting is not required for the feature
    // itself.
    Remapper optimizer(RewriterConfig::AGGRESSIVE);
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));

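    // Fusion is expected only when the add has no broadcasting; with
    // broadcast inputs the graph should be left unfused.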
    bool check_fusion = !add_with_bcast;
    int found = 0;
    for (const NodeDef& node : output.node()) {
      auto fetch_node_name = activation != "None" ? "activation" : add_op;
      if (node.name() == fetch_node_name) {
        if (check_fusion) {
          EXPECT_EQ("_FusedConv2D", node.op());
          EXPECT_EQ("input", node.input(0));
          EXPECT_EQ("filter", node.input(1));

          EXPECT_EQ(2, node.attr().at("num_args").i());
          EXPECT_EQ("bias", node.input(2));
          EXPECT_EQ("input_addn", node.input(3));

          const auto fused_ops = node.attr().at("fused_ops").list().s();
          if (activation != "None") {
            EXPECT_EQ(3, fused_ops.size());
            EXPECT_EQ("BiasAdd", fused_ops[0]);
            EXPECT_EQ("Add", fused_ops[1]);
            EXPECT_EQ(activation, fused_ops[2]);
          } else {
            EXPECT_EQ(2, fused_ops.size());
            EXPECT_EQ("BiasAdd", fused_ops[0]);
            EXPECT_EQ("Add", fused_ops[1]);
          }
        } else {
          if (activation != "None") {
            EXPECT_EQ(node.op(), activation);
            ASSERT_EQ(node.input_size(), 1);
            EXPECT_EQ(node.input(0), add_op);
          } else {
            EXPECT_EQ(node.op(), add_op);
            ASSERT_EQ(node.input_size(), 2);
          }
        }
        found++;
      }
    }
    EXPECT_EQ(1, found);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    // Using relative tolerance since oneDNN could produce different results
    // when float32 numbers need to be rounded during accumulation.
    test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
  }
};

#define CREATE_CONV2DFUSION_TEST(data_format, addop, activation, bcast)       \
  TEST_F(                                                                     \
      MklRemapperTest,                                                        \
      FuseConv2DWithBiasAnd##addop##_##data_format##_activation##activation##_addbcast##bcast) { \
    FuseConv2DWithBiasAndAddNOrAdd(#data_format, #activation, #addop, bcast); \
  }

#define CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(data_format, addop, bcast) \
  CREATE_CONV2DFUSION_TEST(data_format, addop, Relu, bcast);               \
  CREATE_CONV2DFUSION_TEST(data_format, addop, Relu6, bcast);              \
  CREATE_CONV2DFUSION_TEST(data_format, addop, Elu, bcast);                \
  CREATE_CONV2DFUSION_TEST(data_format, addop, LeakyRelu, bcast);          \
  CREATE_CONV2DFUSION_TEST(data_format, addop, None, bcast);

#define CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(addop)            \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false);

CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(AddN);

#define CREATE_CONV2DFUSION_ADD_BCAST_TEST(addop)              \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false); \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, true);  \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, true);

CREATE_CONV2DFUSION_ADD_BCAST_TEST(Add);
CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);

#undef CREATE_CONV2DFUSION_ADD_NOBCAST_TEST
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
#undef CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST
#undef CREATE_CONV2DFUSION_TEST

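// The REGISTER_TEST macro below builds DepthwiseConv2dNative -> BiasAdd
// (optionally followed by an activation) and checks that the remapper fuses
// the pattern into a single _FusedDepthwiseConv2dNative node.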
#define REGISTER_TEST(NAME, T, INPUT)                                         \
  TEST_F(MklRemapperTest, NAME##_##T) {                                       \
    using ::tensorflow::ops::Placeholder;                                     \
                                                                              \
    for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {       \
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();                \
                                                                              \
      auto input_shape = Placeholder::Shape({8, 32, 32, 3});                  \
      auto filter_shape = Placeholder::Shape({1, 1, 3, 1});                   \
      auto bias_shape = Placeholder::Shape({3});                              \
                                                                              \
      auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); \
      auto filter =                                                           \
          Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);        \
      auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);    \
                                                                              \
      std::vector<int> strides = {1, 1, 1, 1};                                \
      auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),  \
                                             input, filter, strides, "SAME"); \
      auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);     \
                                                                              \
      ops::Identity fetch = [&]() -> ops::Identity {                          \
        auto activate = s.WithOpName("activation");                           \
        auto fetch = s.WithOpName("fetch");                                   \
                                                                              \
        if (activation == "Relu") {                                           \
          return ops::Identity(fetch, ops::Relu(activate, bias_add));         \
        } else if (activation == "Relu6") {                                   \
          return ops::Identity(fetch, ops::Relu6(activate, bias_add));        \
        } else if (activation == "Elu") {                                     \
          return ops::Identity(fetch, ops::Elu(activate, bias_add));          \
        }                                                                     \
                                                                              \
        DCHECK(activation == "None");                                         \
        return ops::Identity(fetch, bias_add);                                \
      }();                                                                    \
                                                                              \
      auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});          \
      auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});           \
      auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});                      \
                                                                              \
      GrapplerItem item;                                                      \
      item.fetch = {"fetch"};                                                 \
      item.feed = {                                                           \
          {"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};        \
      TF_CHECK_OK(s.ToGraphDef(&item.graph));                                 \
                                                                              \
      for (int i = 0; i < item.graph.node_size(); ++i) {                      \
        item.graph.mutable_node(i)->set_device("/device:CPU:0");              \
      }                                                                       \
                                                                              \
      Remapper optimizer(RewriterConfig::ON);                                 \
      GraphDef output;                                                        \
      TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));                \
                                                                              \
      int found = 0;                                                          \
      for (const NodeDef& node : output.node()) {                             \
        if (node.name() != "bias_add" && node.name() != "activation")         \
          continue;                                                           \
                                                                              \
        EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");                  \
        ASSERT_EQ(node.input_size(), 3);                                      \
        EXPECT_EQ(node.input(0), "input");                                    \
        EXPECT_EQ(node.input(1), "filter");                                   \
                                                                              \
        EXPECT_EQ(node.attr().at("num_args").i(), 1);                         \
        EXPECT_EQ(node.input(2), "bias");                                     \
                                                                              \
        const auto fused_ops = node.attr().at("fused_ops").list().s();        \
        if (node.name() == "bias_add") {                                      \
          ASSERT_EQ(fused_ops.size(), 1);                                     \
          EXPECT_EQ(fused_ops[0], "BiasAdd");                                 \
          found++;                                                            \
        }                                                                     \
        if (node.name() == "activation") {                                    \
          ASSERT_EQ(fused_ops.size(), 2);                                     \
          EXPECT_EQ(fused_ops[0], "BiasAdd");                                 \
          EXPECT_EQ(fused_ops[1], activation);                                \
          found++;                                                            \
        }                                                                     \
      }                                                                       \
      EXPECT_EQ(found, 1);                                                    \
                                                                              \
      auto tensors_expected =                                                 \
          EvaluateNodes(item.graph, item.fetch, item.feed);                   \
      ASSERT_EQ(tensors_expected.size(), 1);                                  \
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);            \
      ASSERT_EQ(tensors.size(), 1);                                           \
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);   \
    }                                                                         \
  }
REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation);
#undef REGISTER_TEST

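// Checks that FusedBatchNormV3 followed by Relu is fused into
// _FusedBatchNormEx when there is no side input; with a side input the
// pattern is expected to stay unfused on CPU.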
TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
  using ::tensorflow::ops::Placeholder;

  for (bool is_training : {true, false}) {
    for (bool has_side_input : {true, false}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();

      const int num_channels = 24;

      TensorShape channel_shape({num_channels});
      TensorShape empty_shape({0});

      auto input =
          Placeholder(s.WithOpName("input"), DT_FLOAT,
                      ops::Placeholder::Shape({2, 8, 8, num_channels}));
      auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
      auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
      auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
      auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
      auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);

      float epsilon = 0.1f;
      auto fbn =
          ops::FusedBatchNormV3(s.WithOpName("fused_batch_norm"), input_cast,
                                scale, offset, mean, var,
                                ops::FusedBatchNormV3::IsTraining(is_training)
                                    .Epsilon(epsilon)
                                    .DataFormat("NHWC"));

      if (has_side_input) {
        auto side_input =
            Placeholder(s.WithOpName("side_input"), DT_FLOAT,
                        ops::Placeholder::Shape({2, 8, 8, num_channels}));
        auto side_input_cast =
            ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT);
        auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
        auto relu = ops::Relu(s.WithOpName("relu"), add);
      } else {
        auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
      }

      auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
      auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
      auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
      auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
                                                               : channel_shape);
      auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
                                                              : channel_shape);
      auto side_input_t =
          GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});

      GrapplerItem item;
      item.fetch = {"relu"};
      if (has_side_input)
        item.feed = {{"input", input_t},   {"scale", scale_t},
                     {"offset", offset_t}, {"mean", mean_t},
                     {"var", var_t},       {"side_input", side_input_t}};
      else
        item.feed = {{"input", input_t},
                     {"scale", scale_t},
                     {"offset", offset_t},
                     {"mean", mean_t},
                     {"var", var_t}};
      TF_ASSERT_OK(s.ToGraphDef(&item.graph));

      // Place all nodes on CPU.
      for (int i = 0; i < item.graph.node_size(); ++i) {
        item.graph.mutable_node(i)->set_device("/device:CPU:0");
      }

      Remapper optimizer(RewriterConfig::AGGRESSIVE);
      GraphDef output;
      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

      int found = 0;
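      // With a side input the graph should remain unfused: FusedBatchNormV3,
      // Add, and Relu are all expected to survive unchanged.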
      if (has_side_input) {
        for (const NodeDef& node : output.node()) {
          if (node.name() == "add") {
            EXPECT_EQ(node.op(), "Add");
            ASSERT_EQ(node.input_size(), 2);
            EXPECT_EQ(node.input(0), "fused_batch_norm");
            EXPECT_EQ(node.input(1), "side_input_cast");
            found++;
          }
          if (node.name() == "relu") {
            EXPECT_EQ(node.op(), "Relu");
            ASSERT_EQ(node.input_size(), 1);
            EXPECT_EQ(node.input(0), "add");
            found++;
          }
          if (node.name() == "fused_batch_norm") {
            EXPECT_EQ(node.op(), "FusedBatchNormV3");
            ASSERT_EQ(node.input_size(), 5);
            EXPECT_EQ(node.input(0), "input_cast");
            EXPECT_EQ(node.input(1), "scale");
            EXPECT_EQ(node.input(2), "offset");
            EXPECT_EQ(node.input(3), "mean");
            EXPECT_EQ(node.input(4), "var");
            found++;
          }
        }
        EXPECT_EQ(found, 3);
      } else {
        for (const NodeDef& node : output.node()) {
          if (node.name() == "relu") {
            EXPECT_EQ(node.op(), "Identity");
            ASSERT_EQ(node.input_size(), 1);
            EXPECT_EQ(node.input(0), "fused_batch_norm");
            found++;
          }
          if (node.name() == "fused_batch_norm") {
            EXPECT_EQ(node.op(), "_FusedBatchNormEx");
            ASSERT_EQ(node.input_size(), 5);
            EXPECT_EQ(node.input(0), "input_cast");
            EXPECT_EQ(node.input(1), "scale");
            EXPECT_EQ(node.input(2), "offset");
            EXPECT_EQ(node.input(3), "mean");
            EXPECT_EQ(node.input(4), "var");

            auto attr = node.attr();
            EXPECT_EQ(attr["num_side_inputs"].i(), 0);
            EXPECT_EQ(attr["activation_mode"].s(), "Relu");
            found++;
          }
        }
        EXPECT_EQ(found, 2);
      }

      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
      ASSERT_EQ(tensors_expected.size(), 1);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
    }
  }
}

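// Builds MatMul -> {BiasAdd|AddV2|Add} -> Add and checks that the remapper
// fuses the whole pattern into a single _FusedMatMul node with
// fused_ops = {BiasAdd, Add}.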
TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
  using ::tensorflow::ops::Placeholder;

  for (const string& add_op : {"BiasAdd", "AddV2", "Add"}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto input_shape = ops::Placeholder::Shape({4, 32});
    auto input_shape_add = ops::Placeholder::Shape({4, 8});
    auto filter_shape = ops::Placeholder::Shape({32, 8});
    auto bias_shape = ops::Placeholder::Shape({8});

    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
    auto input_add =
        Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);

    auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
    Output bias_add;
    if (add_op == "BiasAdd")
      bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
    else if (add_op == "AddV2")
      bias_add = ops::AddV2(s.WithOpName("bias_add"), matmul, bias);
    else if (add_op == "Add")
      bias_add = ops::Add(s.WithOpName("bias_add"), bias, matmul);

    auto fetch = s.WithOpName("fetch");
    auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);

    ops::Identity(fetch, add);

    auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape.shape_.dim_sizes()));
    auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape_add.shape_.dim_sizes()));
    auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(filter_shape.shape_.dim_sizes()));
    auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(bias_shape.shape_.dim_sizes()));

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"input", input_tensor},
                 {"filter", filter_tensor},
                 {"bias", bias_tensor},
                 {"input_add", input_add_tensor}};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::AGGRESSIVE);
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));

    int found = 0;
    for (const NodeDef& node : output.node()) {
      auto fetch_node_name = "add";
      if (node.name() == fetch_node_name) {
        EXPECT_EQ("_FusedMatMul", node.op());
        EXPECT_EQ("input", node.input(0));
        EXPECT_EQ("filter", node.input(1));
        EXPECT_EQ(2, node.attr().at("num_args").i());
        EXPECT_EQ("bias", node.input(2));
        EXPECT_EQ("input_add", node.input(3));

        const auto fused_ops = node.attr().at("fused_ops").list().s();
        EXPECT_EQ(2, fused_ops.size());
        EXPECT_EQ("BiasAdd", fused_ops[0]);
        EXPECT_EQ("Add", fused_ops[1]);
        found++;
      }
    }
    EXPECT_EQ(1, found);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
  }
}

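// Checks that an Add/AddV2 of a DepthwiseConv2dNative output and a bias-like
// rank-1 operand is rewritten as a BiasAdd and then fused into
// _FusedDepthwiseConv2dNative.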
class RelpaceAddWithBiasAddTest : public GrapplerTest {
 public:
  const string kAddOp = "Add";
  const string kAddV2Op = "AddV2";

 protected:
  template <DataType DTYPE>
  void RelpaceAddWithBiasAddDepthwiseConv2D(const string& add_op) {
    using ::tensorflow::ops::Placeholder;

    for (const string& activation : {"None", "Relu", "Relu6", "Elu"}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();

      auto input_shape = Placeholder::Shape({8, 32, 32, 3});
      auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
      auto bias_shape = Placeholder::Shape({128 * 3});

      auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape);
      auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape);
      auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);

      std::vector<int> strides = {1, 1, 1, 1};
      auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
                                             input, filter, strides, "SAME");

      Output bias_add;
      if (add_op == kAddV2Op) {
        bias_add = ops::AddV2(s.WithOpName(add_op), conv, bias);
      } else {
        bias_add = ops::Add(s.WithOpName(add_op), bias, conv);
      }

      ops::Identity fetch = [&]() -> ops::Identity {
        auto activate = s.WithOpName("activation");
        auto fetch = s.WithOpName("fetch");

        if (activation == "Relu") {
          return ops::Identity(fetch, ops::Relu(activate, bias_add));
        } else if (activation == "Relu6") {
          return ops::Identity(fetch, ops::Relu6(activate, bias_add));
        } else if (activation == "Elu") {
          return ops::Identity(fetch, ops::Elu(activate, bias_add));
        }

        return ops::Identity(fetch, bias_add);
      }();

      auto input_t = GenerateRandomTensor<DTYPE>({8, 32, 32, 3});
      auto filter_t = GenerateRandomTensor<DTYPE>({1, 1, 3, 128});
      auto bias_t = GenerateRandomTensor<DTYPE>({128 * 3});

      GrapplerItem item;
      item.fetch = {"fetch"};
      item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
      TF_ASSERT_OK(s.ToGraphDef(&item.graph));

      // Place all nodes on CPU.
      for (int i = 0; i < item.graph.node_size(); ++i) {
        item.graph.mutable_node(i)->set_device("/device:CPU:0");
      }

      Remapper optimizer(RewriterConfig::AGGRESSIVE);
      GraphDef output;
      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

      int found = 0;
      for (const NodeDef& node : output.node()) {
        if (node.name() == "activation") {
          EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
          ASSERT_GE(node.input_size(), 3);
          EXPECT_EQ(node.input(0), "input");
          EXPECT_EQ(node.input(1), "filter");
          EXPECT_EQ(node.attr().at("num_args").i(), 1);
          EXPECT_EQ(node.input(2), "bias");

          const auto fused_ops = node.attr().at("fused_ops").list().s();
          ASSERT_EQ(fused_ops.size(), 2);
          EXPECT_EQ(fused_ops[0], "BiasAdd");
          EXPECT_EQ(fused_ops[1], activation);

          found++;
        } else if (node.name() == add_op) {
          EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
          ASSERT_GE(node.input_size(), 3);
          EXPECT_EQ(node.input(0), "input");
          EXPECT_EQ(node.input(1), "filter");
          EXPECT_EQ(node.attr().at("num_args").i(), 1);
          EXPECT_EQ(node.input(2), "bias");

          const auto fused_ops = node.attr().at("fused_ops").list().s();
          ASSERT_EQ(fused_ops.size(), 1);
          EXPECT_EQ(fused_ops[0], "BiasAdd");
          found++;
        }
      }
      EXPECT_EQ(found, 1);

      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
      ASSERT_EQ(tensors_expected.size(), 1);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 1);

      if (DTYPE == DT_BFLOAT16)
        test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
      else
        test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
    }
  }
};

#define CREATE_REPLACEADDWITHBIASADD_TEST_1(ops, addop, dtype)              \
  TEST_F(RelpaceAddWithBiasAddTest, RelpaceAddWithBiasAdd##ops##_##addop) { \
    RelpaceAddWithBiasAddDepthwiseConv2D<dtype>(#addop);                    \
  }
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, AddV2, DT_FLOAT);
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, Add, DT_FLOAT);

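// Builds the exact Gelu subgraph from elementwise ops,
// Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), and checks that
// MatMul -> BiasAdd -> Gelu-exact is fused into _FusedMatMul with
// fused_ops = {BiasAdd, GeluExact}.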
class FusedMatMulBiasAddAndGeluTest : public GrapplerTest {
 public:
  template <DataType DTYPE>
  void RunTest() {
    using ::tensorflow::ops::Placeholder;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto lhs_shape = ops::Placeholder::Shape({8, 32});
    auto rhs_shape = ops::Placeholder::Shape({32, 64});
    auto bias_shape = ops::Placeholder::Shape({64});

    auto lhs = Placeholder(s.WithOpName("lhs"), DTYPE, lhs_shape);
    auto rhs = Placeholder(s.WithOpName("rhs"), DTYPE, rhs_shape);
    auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);

    auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);

    // Build Gelu exact (erf-based) from smaller elementwise ops.
    auto square_root_one_half =
        ops::Const(s.WithOpName("square_root_one_half"), {0.707106f}, {});
    auto bias_add_times_square_root_one_half =
        ops::Mul(s.WithOpName("bias_add_times_square_root_one_half"), bias_add,
                 square_root_one_half);
    auto erf =
        ops::Erf(s.WithOpName("erf"), bias_add_times_square_root_one_half);
    auto one = ops::Const(s.WithOpName("one"), {1.0f}, {});
    auto erf_plus_one = ops::AddV2(s.WithOpName("one_plus_erf"), erf, one);
    auto one_half = ops::Const(s.WithOpName("one_half"), {0.5f}, {});
    auto erf_plus_one_times_one_half = ops::Mul(
        s.WithOpName("erf_plus_one_times_one_half"), erf_plus_one, one_half);
    auto gelu = ops::Mul(s.WithOpName("fusion_output"),
                         erf_plus_one_times_one_half, bias_add);
    auto fetch = ops::Identity(s.WithOpName("fetch"), gelu);

    auto lhs_t = GenerateTensorWithSetRandom<DTYPE>({8, 32});
    auto rhs_t = GenerateTensorWithSetRandom<DTYPE>({32, 64});
    auto bias_t = GenerateTensorWithSetRandom<DTYPE>({64});

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::ON);
    GraphDef optimized_graph;
    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
    int found = 0;
    for (const NodeDef& node : optimized_graph.node()) {
      if (node.name() == "fusion_output") {
        EXPECT_EQ(node.op(), "_FusedMatMul");
        ASSERT_GE(node.input_size(), 3);
        EXPECT_EQ(node.input(0), "lhs");
        EXPECT_EQ(node.input(1), "rhs");
        EXPECT_EQ(node.input(2), "bias");
        EXPECT_EQ(node.attr().at("num_args").i(), 1);
        const auto fused_ops = node.attr().at("fused_ops").list().s();
        ASSERT_EQ(fused_ops.size(), 2);
        EXPECT_EQ(fused_ops[0], "BiasAdd");
        EXPECT_EQ(fused_ops[1], "GeluExact");
        found++;
      }
    }
    EXPECT_EQ(1, found);

    // Evaluate result without remapper fusion.
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_expected.size(), 1);

    auto tensors_evaluated =
        EvaluateNodes(optimized_graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_evaluated.size(), 1);
    test::ExpectClose(tensors_evaluated[0], tensors_expected[0], 1e-6);
  }
};

// Gelu has two implementations: (1) exact and (2) approximate. The exact form
// cannot be used with bfloat16 since Erf is not yet supported in bfloat16, so
// Gelu-exact is tested here for float32 only. The Gelu-approximate test lives
// in tensorflow/python/grappler/remapper_test.py, since that pattern is
// changed by other optimizers before the remapper runs.
TEST_F(FusedMatMulBiasAddAndGeluTest, Float32GeluExact) { RunTest<DT_FLOAT>(); }

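// Checks that BatchMatMulV2 followed by Mul (scale) and AddV2 (addend) is
// fused into _MklFusedBatchMatMulV2 with fused_ops = {Mul, Add}.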
class MklFusedBatchMatMul : public MklRemapperTest {
 public:
  template <typename T>
  void VerifyFused(bool adjx, bool adjy) {
    using ::tensorflow::ops::Placeholder;
    using normal_generator = Eigen::internal::NormalRandomGenerator<T>;

    int b0 = 2;
    int b1 = 2;
    int m = 32;
    int k = 16;
    int n = 64;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto input_shape =
        adjx ? TensorShape({b0, b1, k, m}) : TensorShape({b0, b1, m, k});
    auto weight_shape =
        adjy ? TensorShape({b0, b1, n, k}) : TensorShape({b0, b1, k, n});
    auto add_shape = TensorShape({b0, 1, m, n});

    auto input_placeholder_shape = ops::Placeholder::Shape(input_shape);
    auto weight_placeholder_shape = ops::Placeholder::Shape(weight_shape);
    auto add_placeholder_shape = ops::Placeholder::Shape(add_shape);

    auto input = Placeholder(s.WithOpName("input"), DataTypeToEnum<T>::v(),
                             input_placeholder_shape);
    auto weight = Placeholder(s.WithOpName("weight"), DataTypeToEnum<T>::v(),
                              weight_placeholder_shape);
    auto addend = Placeholder(s.WithOpName("addend"), DataTypeToEnum<T>::v(),
                              add_placeholder_shape);

    auto batchmatmul =
        ops::BatchMatMulV2(s.WithOpName("batchmatmul"), input, weight,
                           ops::BatchMatMulV2::Attrs().AdjX(adjx).AdjY(adjy));
    auto scale_const = ops::Const(s.WithOpName("scale_const"), {0.1f});
    auto scale =
        ops::Cast(s.WithOpName("scale"), scale_const, DataTypeToEnum<T>::v());
    auto mul = ops::Multiply(s.WithOpName("mul"), batchmatmul, scale);
    auto add = ops::AddV2(s.WithOpName("add"), mul, addend);
    auto fetch = ops::Identity(s.WithOpName("fetch"), add);

    Tensor input_t = Tensor(DataTypeToEnum<T>::v(), input_shape);
    Tensor weight_t = Tensor(DataTypeToEnum<T>::v(), weight_shape);
    Tensor add_t = Tensor(DataTypeToEnum<T>::v(), add_shape);
    input_t.flat<T>() =
        input_t.flat<T>().template setRandom<normal_generator>();
    weight_t.flat<T>() =
        weight_t.flat<T>().template setRandom<normal_generator>();
    add_t.flat<T>() = add_t.flat<T>().template setRandom<normal_generator>();

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"input", input_t}, {"weight", weight_t}, {"addend", add_t}};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::ON);
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));

    int found = 0;
    for (const NodeDef& node : output.node()) {
      if (node.name() == "add") {
        EXPECT_EQ("_MklFusedBatchMatMulV2", node.op());
        EXPECT_EQ("input", node.input(0));
        EXPECT_EQ("weight", node.input(1));
        EXPECT_EQ("scale", node.input(2));
        EXPECT_EQ("addend", node.input(3));
        const auto fused_ops = node.attr().at("fused_ops").list().s();
        EXPECT_EQ(2, fused_ops.size());
        EXPECT_EQ("Mul", fused_ops[0]);
        EXPECT_EQ("Add", fused_ops[1]);
        found++;
      }
    }
    EXPECT_EQ(1, found);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    std::is_same<T, float>::value
        ? test::ExpectClose(tensors_expected[0], tensors[0], 1e-6, 1e-6)
        : test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, 1e-2);
  }
};

TEST_F(MklFusedBatchMatMul, MulAndAdd) {
  for (const auto adjx : {false, true})
    for (const auto adjy : {false, true}) {
      this->VerifyFused<float>(adjx, adjy);
      this->VerifyFused<bfloat16>(adjx, adjy);
    }
}

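// Checks that x * sigmoid(x) (in either operand order) is fused into a single
// _MklSwish node.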
class MklRemapperSwishTest : public GrapplerTest {
 protected:
  template <DataType DTYPE>
  void RunTest() {
    using ::tensorflow::ops::Placeholder;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    auto mul_shape = ops::Placeholder::Shape({64, 64});

    // We test four situations:
    //   1. y = x * sigmoid(x)
    //   2. y = sigmoid(x) * x
    //   3. y = sigmoid(x) * sigmoid(sigmoid(x))
    //   4. y = sigmoid(sigmoid(x)) * sigmoid(x)
    auto input = Placeholder(s.WithOpName("input"), DTYPE, mul_shape);
    auto sigmoid1 = ops::Sigmoid(s.WithOpName("sigmoid1"), input);
    auto sigmoid2 = ops::Sigmoid(s.WithOpName("sigmoid2"), input);
    auto sigmoid3_1 = ops::Sigmoid(s.WithOpName("sigmoid3_1"), input);
    auto sigmoid3_2 = ops::Sigmoid(s.WithOpName("sigmoid3_2"), sigmoid3_1);
    auto sigmoid4_1 = ops::Sigmoid(s.WithOpName("sigmoid4_1"), input);
    auto sigmoid4_2 = ops::Sigmoid(s.WithOpName("sigmoid4_2"), sigmoid4_1);
    auto mul1 = ops::Mul(s.WithOpName("mul1"), input, sigmoid1);
    auto mul2 = ops::Mul(s.WithOpName("mul2"), sigmoid2, input);
    auto mul3 = ops::Mul(s.WithOpName("mul3"), sigmoid3_1, sigmoid3_2);
    auto mul4 = ops::Mul(s.WithOpName("mul4"), sigmoid4_2, sigmoid4_1);
    auto fetch1 = ops::Identity(s.WithOpName("fetch1"), mul1);
    auto fetch2 = ops::Identity(s.WithOpName("fetch2"), mul2);
    auto fetch3 = ops::Identity(s.WithOpName("fetch3"), mul3);
    auto fetch4 = ops::Identity(s.WithOpName("fetch4"), mul4);
    auto mul_t = GenerateTensorWithSetRandom<DTYPE>({64, 64});

    GrapplerItem item;
    item.fetch = {"fetch1", "fetch2", "fetch3", "fetch4"};
    item.feed = {{"input", mul_t}};
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::ON);
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

    int found = 0;
    for (const NodeDef& node : output.node()) {
      if (node.name() == "mul1") {
        EXPECT_EQ(node.op(), "_MklSwish");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "input");
        ++found;
      }
      if (node.name() == "mul2") {
        EXPECT_EQ(node.op(), "_MklSwish");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "input");
        ++found;
      }
      // mul3 is not replaced by _MklSwish because the pattern matcher does
      // not handle this ordering of the commutative Mul inputs.
      if (node.name() == "mul4") {
        EXPECT_EQ(node.op(), "_MklSwish");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sigmoid4_1");
        ++found;
      }
    }
    EXPECT_EQ(found, 3);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_expected.size(), 4);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    ASSERT_EQ(tensors.size(), 4);
    float atol = 1e-6, rtol = 1e-6;
    if (DTYPE == DT_BFLOAT16) {
      atol = 1e-2;
      rtol = 1e-2;
    }
    test::ExpectClose(tensors[0], tensors_expected[0], atol, rtol);
    test::ExpectClose(tensors[1], tensors_expected[1], atol, rtol);
    test::ExpectClose(tensors[2], tensors_expected[2], atol, rtol);
    test::ExpectClose(tensors[3], tensors_expected[3], atol, rtol);
  }
};

TEST_F(MklRemapperSwishTest, F32) { RunTest<DT_FLOAT>(); }
TEST_F(MklRemapperSwishTest, BF16) { RunTest<DT_BFLOAT16>(); }

}  // namespace grappler
}  // namespace tensorflow
#endif  // INTEL_MKL && ENABLE_MKL