xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/compilation_cache_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <initializer_list>
17 #include <memory>
18 #include <string>
19 
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/client/global_data.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31 #include "tensorflow/compiler/xla/tests/test_utils.h"
32 #include "tensorflow/compiler/xla/xla.pb.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace xla {
37 namespace {
38 
39 class CompilationCacheTest : public ClientLibraryTestBase {
40  public:
ExecuteComputationR0F32(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,float expected_result,bool expect_cache_hit)41   void ExecuteComputationR0F32(const XlaComputation& computation,
42                                absl::Span<GlobalData* const> arguments,
43                                float expected_result, bool expect_cache_hit) {
44     ExecutionProfile execution_profile;
45     Literal result =
46         client_
47             ->ExecuteAndTransfer(computation, arguments,
48                                  /*execution_options=*/&execution_options_,
49                                  &execution_profile)
50             .value();
51     EXPECT_TRUE(LiteralTestUtil::Near(
52         LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
53     EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
54   }
55 
ExecuteComputationR2F32(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,std::initializer_list<std::initializer_list<float>> expected_result,bool expect_cache_hit)56   void ExecuteComputationR2F32(
57       const XlaComputation& computation,
58       absl::Span<GlobalData* const> arguments,
59       std::initializer_list<std::initializer_list<float>> expected_result,
60       bool expect_cache_hit) {
61     ExecutionProfile execution_profile;
62     auto data_handle = client_
63                            ->Execute(computation, arguments,
64                                      &execution_options_, &execution_profile)
65                            .value();
66     Literal result = client_->Transfer(*data_handle).value();
67     EXPECT_TRUE(LiteralTestUtil::Near(
68         LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
69     EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
70   }
71 
72   ErrorSpec error_spec_{0.0001};
73 };
74 
75 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest,DISABLED_ComputationCalledMultipleTimes)76 XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
77   XlaBuilder builder(TestName());
78   Neg(ConstantR0<float>(&builder, 42.0));
79   XlaComputation computation = builder.Build().value();
80 
81   ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
82   ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
83   ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
84 }
85 
86 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest,DISABLED_ComputationCalledWithDifferentParameters)87 XLA_TEST_F(CompilationCacheTest,
88            DISABLED_ComputationCalledWithDifferentParameters) {
89   std::unique_ptr<GlobalData> data_42 =
90       client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f)).value();
91   std::unique_ptr<GlobalData> data_123 =
92       client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f)).value();
93   std::unique_ptr<GlobalData> data_456 =
94       client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f)).value();
95 
96   XlaBuilder builder(TestName());
97   Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param"));
98   XlaComputation computation = builder.Build().value();
99 
100   ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
101                           /*expect_cache_hit=*/false);
102   ExecuteComputationR0F32(computation, {data_123.get()}, -123.0,
103                           /*expect_cache_hit=*/true);
104   ExecuteComputationR0F32(computation, {data_456.get()}, -456.0,
105                           /*expect_cache_hit=*/true);
106   ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
107                           /*expect_cache_hit=*/true);
108 }
109 
110 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest,DISABLED_MultipleComputations)111 XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) {
112   XlaBuilder builder_neg(TestName() + "_neg");
113   Neg(ConstantR0<float>(&builder_neg, 42.0));
114   XlaComputation computation_neg = builder_neg.Build().value();
115 
116   XlaBuilder builder_exp(TestName() + "_exp");
117   Exp(ConstantR0<float>(&builder_exp, 1.0));
118   XlaComputation computation_exp = builder_exp.Build().value();
119 
120   XlaBuilder builder_add(TestName() + "_add");
121   Add(ConstantR0<float>(&builder_add, 2.0),
122       ConstantR0<float>(&builder_add, 3.0));
123   XlaComputation computation_add = builder_add.Build().value();
124 
125   ExecuteComputationR0F32(computation_neg, {}, -42.0,
126                           /*expect_cache_hit=*/false);
127   ExecuteComputationR0F32(computation_exp, {}, 2.7182817,
128                           /*expect_cache_hit=*/false);
129   ExecuteComputationR0F32(computation_add, {}, 5.0,
130                           /*expect_cache_hit=*/false);
131   ExecuteComputationR0F32(computation_neg, {}, -42.0,
132                           /*expect_cache_hit=*/true);
133 }
134 
135 // TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest,DISABLED_DifferentParameterLayouts)136 XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
137   // Create two GlobalData arrays with the same shape but different
138   // layouts. Use these arrays as parameters to a simple computation. If the
139   // layout of the array changes then computation should be recompiled (cache
140   // miss).
141   auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
142       {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
143   auto rowmaj_handle = client_->TransferToServer(rowmaj_array).value();
144 
145   auto colmaj_array = LiteralUtil::CreateR2WithLayout(
146       {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
147   auto colmaj_handle = client_->TransferToServer(colmaj_array).value();
148 
149   XlaBuilder builder(TestName());
150   Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
151   XlaComputation computation = builder.Build().value();
152 
153   ExecuteComputationR2F32(computation, {colmaj_handle.get()},
154                           {{1.0f, 2.0f}, {3.0f, 4.0f}},
155                           /*expect_cache_hit=*/false);
156   ExecuteComputationR2F32(computation, {colmaj_handle.get()},
157                           {{1.0f, 2.0f}, {3.0f, 4.0f}},
158                           /*expect_cache_hit=*/true);
159   ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
160                           {{1.0f, 2.0f}, {3.0f, 4.0f}},
161                           /*expect_cache_hit=*/false);
162   ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
163                           {{1.0f, 2.0f}, {3.0f, 4.0f}},
164                           /*expect_cache_hit=*/true);
165   ExecuteComputationR2F32(computation, {colmaj_handle.get()},
166                           {{1.0f, 2.0f}, {3.0f, 4.0f}},
167                           /*expect_cache_hit=*/true);
168 }
169 
170 }  // namespace
171 }  // namespace xla
172