/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/cc/framework/ops.h"
25 #include "tensorflow/cc/framework/scope.h"
26 #include "tensorflow/cc/ops/array_ops.h"
27 #include "tensorflow/cc/ops/const_op.h"
28 #include "tensorflow/cc/ops/nn_ops.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/core/framework/device_attributes.pb.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/tensor_testutil.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/test.h"
40 #include "tensorflow/core/public/session.h"
41 #include "tensorflow/core/public/session_options.h"
42
43 namespace tensorflow {
44 namespace {
GetTestDevice(Session * session,string * test_device)45 Status GetTestDevice(Session* session, string* test_device) {
46 std::vector<DeviceAttributes> devices;
47 TF_RETURN_IF_ERROR(session->ListDevices(&devices));
48
49 bool found_cpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) {
50 return device.device_type() == "CPU";
51 });
52
53 bool found_gpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) {
54 return device.device_type() == "GPU";
55 });
56
57 if (!found_gpu && !found_cpu) {
58 return errors::Internal("Expected at least one CPU or GPU!");
59 }
60
61 *test_device = found_gpu ? "GPU" : "CPU";
62 VLOG(2) << "Using test device " << *test_device;
63 return OkStatus();
64 }
65
FillZeros(Tensor * tensor)66 void FillZeros(Tensor* tensor) {
67 auto flat = tensor->flat<float>();
68 for (int i = 0; i < flat.size(); i++) {
69 flat.data()[i] = 0.0f;
70 }
71 }
72
// This test checks that the implementation outputs from FusedBatchnorm
// training, reserve_space_{1|2}, are what we assume them to be in the TF/XLA
// lowering.
//
// If this test starts failing then it doesn't indicate that TF/cudnn have
// violated their contract, but it indicates that we need to update the TF/XLA
// lowering for FusedBatchnorm training to match the new implementation defined
// behavior.
// Builds the same FusedBatchNorm (training mode) twice -- once placed on the
// plain TF device and once on the corresponding XLA_* device -- runs both in
// one session, and checks that the reserve_space_1/2 outputs agree.
TEST(FusedBatchnormReserveSpaceTest, Test) {
  using ::tensorflow::ops::Const;
  using ::tensorflow::ops::FusedBatchNorm;

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions{}));

  string test_device;
  TF_ASSERT_OK(GetTestDevice(session.get(), &test_device));

  Scope root = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);

  // scale/offset are per-channel; the input fed below has 10 channels.
  Tensor scale_data(DT_FLOAT, TensorShape({10}));
  FillZeros(&scale_data);
  Output scale =
      Const(root.WithOpName("scale"), Input::Initializer(scale_data));

  Tensor offset_data(DT_FLOAT, TensorShape({10}));
  FillZeros(&offset_data);
  Output offset =
      Const(root.WithOpName("offset"), Input::Initializer(offset_data));

  // mean/variance inputs are empty tensors here since is_training=true.
  Tensor mean_data(DT_FLOAT, TensorShape({0}));
  // Fixed: this op was previously (mis)named "offset", a copy-paste from the
  // offset constant above.
  Output mean = Const(root.WithOpName("mean"), Input::Initializer(mean_data));

  Tensor variance_data(DT_FLOAT, TensorShape({0}));
  Output variance =
      Const(root.WithOpName("variance"), Input::Initializer(variance_data));

  string tf_device = absl::StrCat("/device:", test_device, ":0");
  string xla_device = absl::StrCat("/device:XLA_", test_device, ":0");

  FusedBatchNorm fused_batch_norm_tf(
      root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input,
      scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true));
  FusedBatchNorm fused_batch_norm_xla(
      root.WithOpName("fused_batch_norm_xla").WithDevice(xla_device), input,
      scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true));

  tensorflow::GraphDef graph;
  TF_ASSERT_OK(root.ToGraphDef(&graph));

  TF_ASSERT_OK(session->Create(graph));

  // Deterministic, non-constant input so the batch statistics are nontrivial.
  Tensor input_data(DT_FLOAT, TensorShape({10, 10, 10, 10}));
  auto flat_input = input_data.flat<float>();
  for (int i = 0; i < flat_input.size(); i++) {
    flat_input.data()[i] = (i - 5) / 1000.0f;
  }

  // Fetch order: tf rs1, xla rs1, tf rs2, xla rs2 -- compared pairwise below.
  std::vector<Tensor> results;
  TF_ASSERT_OK(session->Run({{"input", input_data}},
                            {fused_batch_norm_tf.reserve_space_1.name(),
                             fused_batch_norm_xla.reserve_space_1.name(),
                             fused_batch_norm_tf.reserve_space_2.name(),
                             fused_batch_norm_xla.reserve_space_2.name()},
                            {}, &results));

  test::ExpectClose(results[0], results[1], /*atol=*/1e-4);
  test::ExpectClose(results[2], results[3], /*atol=*/1e-4);
}
143
// Runs at static-initialization time (before any test executes) to turn on
// the XLA_* devices; the test above places one of the FusedBatchNorm ops on
// an XLA_<type> device, which requires this flag.
static bool Initialized = [] {
  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
  return true;
}();
148
149 } // namespace
150 } // namespace tensorflow
151