1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <initializer_list>
17 #include <memory>
18 #include <string>
19
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/client/global_data.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31 #include "tensorflow/compiler/xla/tests/test_utils.h"
32 #include "tensorflow/compiler/xla/xla.pb.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace xla {
37 namespace {
38
39 class CompilationCacheTest : public ClientLibraryTestBase {
40 public:
ExecuteComputationR0F32(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,float expected_result,bool expect_cache_hit)41 void ExecuteComputationR0F32(const XlaComputation& computation,
42 absl::Span<GlobalData* const> arguments,
43 float expected_result, bool expect_cache_hit) {
44 ExecutionProfile execution_profile;
45 Literal result =
46 client_
47 ->ExecuteAndTransfer(computation, arguments,
48 /*execution_options=*/&execution_options_,
49 &execution_profile)
50 .value();
51 EXPECT_TRUE(LiteralTestUtil::Near(
52 LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
53 EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
54 }
55
ExecuteComputationR2F32(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,std::initializer_list<std::initializer_list<float>> expected_result,bool expect_cache_hit)56 void ExecuteComputationR2F32(
57 const XlaComputation& computation,
58 absl::Span<GlobalData* const> arguments,
59 std::initializer_list<std::initializer_list<float>> expected_result,
60 bool expect_cache_hit) {
61 ExecutionProfile execution_profile;
62 auto data_handle = client_
63 ->Execute(computation, arguments,
64 &execution_options_, &execution_profile)
65 .value();
66 Literal result = client_->Transfer(*data_handle).value();
67 EXPECT_TRUE(LiteralTestUtil::Near(
68 LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
69 EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
70 }
71
72 ErrorSpec error_spec_{0.0001};
73 };
74
75 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
  XlaBuilder builder(TestName());
  Neg(ConstantR0<float>(&builder, 42.0));
  XlaComputation computation = builder.Build().value();

  // Runs the same computation three times. Only the first run is expected to
  // miss the cache; every subsequent run is expected to hit it.
  for (int run = 0; run < 3; ++run) {
    const bool expect_hit = run > 0;
    ExecuteComputationR0F32(computation, {}, -42.0,
                            /*expect_cache_hit=*/expect_hit);
  }
}
85
86 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest,
           DISABLED_ComputationCalledWithDifferentParameters) {
  // Uploads a scalar F32 literal and returns the owning handle to it.
  auto upload_scalar = [this](float value) {
    return client_->TransferToServer(LiteralUtil::CreateR0<float>(value))
        .value();
  };
  std::unique_ptr<GlobalData> data_42 = upload_scalar(42.0f);
  std::unique_ptr<GlobalData> data_123 = upload_scalar(123.0f);
  std::unique_ptr<GlobalData> data_456 = upload_scalar(456.0f);

  XlaBuilder builder(TestName());
  Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param"));
  XlaComputation computation = builder.Build().value();

  // The computation is compiled once. Feeding it different argument values
  // is expected to reuse the cached compilation.
  struct Run {
    GlobalData* argument;
    float expected;
    bool expect_cache_hit;
  };
  const Run runs[] = {
      {data_42.get(), -42.0, false},
      {data_123.get(), -123.0, true},
      {data_456.get(), -456.0, true},
      {data_42.get(), -42.0, true},
  };
  for (const Run& run : runs) {
    ExecuteComputationR0F32(computation, {run.argument}, run.expected,
                            run.expect_cache_hit);
  }
}
109
110 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) {
  // Three independent computations, each built by its own builder.
  XlaBuilder neg_builder(TestName() + "_neg");
  Neg(ConstantR0<float>(&neg_builder, 42.0));
  const XlaComputation neg_computation = neg_builder.Build().value();

  XlaBuilder exp_builder(TestName() + "_exp");
  Exp(ConstantR0<float>(&exp_builder, 1.0));
  const XlaComputation exp_computation = exp_builder.Build().value();

  XlaBuilder add_builder(TestName() + "_add");
  Add(ConstantR0<float>(&add_builder, 2.0),
      ConstantR0<float>(&add_builder, 3.0));
  const XlaComputation add_computation = add_builder.Build().value();

  // Each computation is new the first time it runs, so each first run is
  // expected to miss the cache.
  ExecuteComputationR0F32(neg_computation, {}, -42.0,
                          /*expect_cache_hit=*/false);
  ExecuteComputationR0F32(exp_computation, {}, 2.7182817,
                          /*expect_cache_hit=*/false);
  ExecuteComputationR0F32(add_computation, {}, 5.0,
                          /*expect_cache_hit=*/false);

  // Re-running the first computation after the others is expected to hit.
  ExecuteComputationR0F32(neg_computation, {}, -42.0,
                          /*expect_cache_hit=*/true);
}
134
135 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
  // Uploads the same 2x2 F32 values using the given layout. The two handles
  // below therefore share a shape but differ in layout. A change in parameter
  // layout is expected to force a recompile (cache miss).
  auto upload_with_layout = [this](const Layout& layout) {
    return client_
        ->TransferToServer(LiteralUtil::CreateR2WithLayout(
            {{1.0f, 2.0f}, {3.0f, 4.0f}}, layout))
        .value();
  };
  auto rowmaj_handle = upload_with_layout(LayoutUtil::MakeLayout({1, 0}));
  auto colmaj_handle = upload_with_layout(LayoutUtil::MakeLayout({0, 1}));

  // Identity computation: the result should equal the input regardless of
  // the input's layout.
  XlaBuilder builder(TestName());
  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
  XlaComputation computation = builder.Build().value();

  struct Step {
    GlobalData* argument;
    bool expect_cache_hit;
  };
  const Step steps[] = {
      {colmaj_handle.get(), false},  // First column-major run compiles.
      {colmaj_handle.get(), true},
      {rowmaj_handle.get(), false},  // Layout changed: recompile expected.
      {rowmaj_handle.get(), true},
      {colmaj_handle.get(), true},   // Column-major variant still cached.
  };
  for (const Step& step : steps) {
    ExecuteComputationR2F32(computation, {step.argument},
                            {{1.0f, 2.0f}, {3.0f, 4.0f}},
                            step.expect_cache_hit);
  }
}
169
170 } // namespace
171 } // namespace xla
172