xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/cc/framework/ops.h"
25 #include "tensorflow/cc/framework/scope.h"
26 #include "tensorflow/cc/ops/array_ops.h"
27 #include "tensorflow/cc/ops/const_op.h"
28 #include "tensorflow/cc/ops/nn_ops.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/core/framework/device_attributes.pb.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/tensor_testutil.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/test.h"
40 #include "tensorflow/core/public/session.h"
41 #include "tensorflow/core/public/session_options.h"
42 
43 namespace tensorflow {
44 namespace {
GetTestDevice(Session * session,string * test_device)45 Status GetTestDevice(Session* session, string* test_device) {
46   std::vector<DeviceAttributes> devices;
47   TF_RETURN_IF_ERROR(session->ListDevices(&devices));
48 
49   bool found_cpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) {
50     return device.device_type() == "CPU";
51   });
52 
53   bool found_gpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) {
54     return device.device_type() == "GPU";
55   });
56 
57   if (!found_gpu && !found_cpu) {
58     return errors::Internal("Expected at least one CPU or GPU!");
59   }
60 
61   *test_device = found_gpu ? "GPU" : "CPU";
62   VLOG(2) << "Using test device " << *test_device;
63   return OkStatus();
64 }
65 
FillZeros(Tensor * tensor)66 void FillZeros(Tensor* tensor) {
67   auto flat = tensor->flat<float>();
68   for (int i = 0; i < flat.size(); i++) {
69     flat.data()[i] = 0.0f;
70   }
71 }
72 
// This test checks that the implementation-defined outputs of FusedBatchNorm
// training, reserve_space_{1|2}, are what we assume them to be in the TF/XLA
// lowering.
//
// If this test starts failing, it doesn't mean that TF/cuDNN have violated
// their contract; it means that the TF/XLA lowering for FusedBatchNorm
// training needs to be updated to match the new implementation-defined
// behavior.
TEST(FusedBatchnormReserveSpaceTest,Test)81 TEST(FusedBatchnormReserveSpaceTest, Test) {
82   using ::tensorflow::ops::Const;
83   using ::tensorflow::ops::FusedBatchNorm;
84 
85   std::unique_ptr<tensorflow::Session> session(
86       tensorflow::NewSession(tensorflow::SessionOptions{}));
87 
88   string test_device;
89   TF_ASSERT_OK(GetTestDevice(session.get(), &test_device));
90 
91   Scope root = tensorflow::Scope::NewRootScope();
92   Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
93 
94   Tensor scale_data(DT_FLOAT, TensorShape({10}));
95   FillZeros(&scale_data);
96   Output scale =
97       Const(root.WithOpName("scale"), Input::Initializer(scale_data));
98 
99   Tensor offset_data(DT_FLOAT, TensorShape({10}));
100   FillZeros(&offset_data);
101   Output offset =
102       Const(root.WithOpName("offset"), Input::Initializer(offset_data));
103 
104   Tensor mean_data(DT_FLOAT, TensorShape({0}));
105   Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data));
106 
107   Tensor variance_data(DT_FLOAT, TensorShape({0}));
108   Output variance =
109       Const(root.WithOpName("variance"), Input::Initializer(variance_data));
110 
111   string tf_device = absl::StrCat("/device:", test_device, ":0");
112   string xla_device = absl::StrCat("/device:XLA_", test_device, ":0");
113 
114   FusedBatchNorm fused_batch_norm_tf(
115       root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input,
116       scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true));
117   FusedBatchNorm fused_batch_norm_xla(
118       root.WithOpName("fused_batch_norm_xla").WithDevice(xla_device), input,
119       scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true));
120 
121   tensorflow::GraphDef graph;
122   TF_ASSERT_OK(root.ToGraphDef(&graph));
123 
124   TF_ASSERT_OK(session->Create(graph));
125 
126   Tensor input_data(DT_FLOAT, TensorShape({10, 10, 10, 10}));
127   auto flat_input = input_data.flat<float>();
128   for (int i = 0; i < flat_input.size(); i++) {
129     flat_input.data()[i] = (i - 5) / 1000.0f;
130   }
131 
132   std::vector<Tensor> results;
133   TF_ASSERT_OK(session->Run({{"input", input_data}},
134                             {fused_batch_norm_tf.reserve_space_1.name(),
135                              fused_batch_norm_xla.reserve_space_1.name(),
136                              fused_batch_norm_tf.reserve_space_2.name(),
137                              fused_batch_norm_xla.reserve_space_2.name()},
138                             {}, &results));
139 
140   test::ExpectClose(results[0], results[1], /*atol=*/1e-4);
141   test::ExpectClose(results[2], results[3], /*atol=*/1e-4);
142 }
143 
__anon4d8454e20402null144 static bool Initialized = [] {
145   tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
146   return true;
147 }();
148 
149 }  // namespace
150 }  // namespace tensorflow
151