xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/client/client.h"
22 #include "tensorflow/compiler/xla/client/client_library.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/padding.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/local_service.h"
29 #include "tensorflow/compiler/xla/service/service.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace xla {
38 namespace {
39 
40 constexpr int64_t kPointerSize = 8;
41 
ShapeSize(const Shape & shape)42 int64_t ShapeSize(const Shape& shape) {
43   return ShapeUtil::ByteSizeOf(shape, kPointerSize);
44 }
45 
46 // This test suite tests the HLO cost analysis by first building a computation
47 // using the client computation builder and running the HloCostAnalysis that
48 // returns the number of floating point and transcendental operations in the
49 // graph. We test both individual HLO operations as well as a mixed graph.
50 class HloCostAnalysisTest : public ::testing::Test {
51  protected:
HloCostAnalysisTest()52   HloCostAnalysisTest()
53       : client_(ClientLibrary::LocalClientOrDie()),
54         // Accessing service instance is required for the unit tests to enable
55         // whitebox accesses to the user computation built from the client,
56         // as shown in the BuildHloGraph functions below.
57         service_(static_cast<Service*>(ClientLibrary::GetXlaService(
58             static_cast<LocalClient*>(client_)->platform()))) {
59     // Create a computation for a unary user function: x => exp(x + 0.5)
60     {
61       XlaBuilder builder("add_and_exp");
62       auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
63       auto half = ConstantR0<float>(&builder, 0.5);
64       Exp(Add(x, half));
65       auto computation_status = builder.Build();
66       TF_CHECK_OK(computation_status.status());
67       add_and_exp_ = std::move(computation_status).value();
68     }
69 
70     // Create a computation for a binary user function: (x, y) => x + y
71     {
72       XlaBuilder builder("add");
73       auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
74       auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
75       Add(x, y);
76       auto computation_status = builder.Build();
77       TF_CHECK_OK(computation_status.status());
78       add_ = std::move(computation_status).value();
79     }
80 
81     // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
82     {
83       XlaBuilder builder("sigmoid");
84       auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
85       auto one = ConstantR0<float>(&builder, 1.0);
86       Div(one, Add(one, Exp(Neg(x))));
87       auto computation_status = builder.Build();
88       TF_CHECK_OK(computation_status.status());
89       sigmoid_ = std::move(computation_status).value();
90     }
91 
92     // Create a computation for a binary max function: (x, y) => max (x, y)
93     {
94       XlaBuilder builder("max");
95       auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
96       auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
97       Max(x, y);
98       auto computation_status = builder.Build();
99       TF_CHECK_OK(computation_status.status());
100       max_ = std::move(computation_status).value();
101     }
102 
103     // Create a computation for a binary GT function: (x, y) => x > y
104     {
105       XlaBuilder builder("gt");
106       auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
107       auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
108       Gt(x, y);
109       auto computation_status = builder.Build();
110       TF_CHECK_OK(computation_status.status());
111       gt_ = std::move(computation_status).value();
112     }
113   }
114 
115   // Build HLO graph from the given builder and return the HLO module.
BuildHloGraph(XlaBuilder * builder)116   std::unique_ptr<HloModule> BuildHloGraph(XlaBuilder* builder) {
117     auto computation_status = builder->Build();
118     TF_CHECK_OK(computation_status.status());
119     auto computation = std::move(computation_status).value();
120     auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
121                                                          DebugOptions())
122                       .value();
123     return HloModule::CreateFromProto(computation.proto(), config).value();
124   }
125 
126   Client* client_;
127   Service* service_;
128 
129   // User computations used for higher order operations (e.g., Map, Reduce).
130   XlaComputation add_;
131   XlaComputation add_and_exp_;
132   XlaComputation sigmoid_;
133   XlaComputation max_;
134   XlaComputation gt_;
135 };
136 
137 using HloCostAnalysisHloTest = HloTestBase;
138 
// Cost of a single [10x5] x [5x30] matrix product.
TEST_F(HloCostAnalysisTest, MatrixMultiply) {
  XlaBuilder builder("matrix_multiply");
  const XlaOp lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  const XlaOp rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  Dot(lhs, rhs);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // 10 * 30 output elements with 5 FMAs each (1500 FMAs), 2 flops per FMA.
  EXPECT_EQ(cost.flop_count(), 2 * 10 * 30 * 5);
  EXPECT_EQ(cost.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 5);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * 5 * 30);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 10 * 30);
}
165 
// DotGeneral of [10x5x5] with [5x5x30], contracting two dimensions on each
// side, which leaves a [10x30] result.
TEST_F(HloCostAnalysisTest, DotGeneral) {
  XlaBuilder builder("matrix_multiply");
  const XlaOp lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs");
  const XlaOp rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs");

  // Contract lhs dims {1, 2} against rhs dims {0, 1}.
  DotDimensionNumbers dot_dims;
  dot_dims.add_lhs_contracting_dimensions(1);
  dot_dims.add_lhs_contracting_dimensions(2);
  dot_dims.add_rhs_contracting_dimensions(0);
  dot_dims.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dot_dims);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // 10 * 30 output elements with 5 * 5 FMAs each (7500 FMAs), 2 flops per FMA.
  EXPECT_EQ(cost.flop_count(), 2 * 10 * 30 * 5 * 5);
  EXPECT_EQ(cost.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0),
            sizeof(float) * 10 * 5 * 5);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1),
            sizeof(float) * 5 * 5 * 30);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 10 * 30);
}
201 
// DotGeneral of [10x5x5] with [5x5x30] using one contracting and one batch
// dimension per side; the expected output holds 5 * 10 * 30 elements.
TEST_F(HloCostAnalysisTest, DotGeneral2) {
  XlaBuilder builder("matrix_multiply");
  const XlaOp lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs");
  const XlaOp rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs");

  // lhs: contract dim 1, batch dim 2. rhs: contract dim 0, batch dim 1.
  DotDimensionNumbers dot_dims;
  dot_dims.add_lhs_contracting_dimensions(1);
  dot_dims.add_lhs_batch_dimensions(2);
  dot_dims.add_rhs_contracting_dimensions(0);
  dot_dims.add_rhs_batch_dimensions(1);
  DotGeneral(lhs, rhs, dot_dims);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // 5 * 10 * 30 output elements with 5 FMAs each (7500 FMAs), 2 flops per FMA.
  EXPECT_EQ(cost.flop_count(), 2 * 10 * 30 * 5 * 5);
  EXPECT_EQ(cost.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0),
            sizeof(float) * 10 * 5 * 5);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1),
            sizeof(float) * 5 * 5 * 30);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 5 * 10 * 30);
}
237 
// DotGeneral of [10x5] with [5x30] with no contracting or batch dimensions
// set; the expected output holds 10 * 5 * 5 * 30 elements.
TEST_F(HloCostAnalysisTest, DotGeneral3) {
  XlaBuilder builder("matrix_multiply");
  const XlaOp lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
  const XlaOp rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
  DotDimensionNumbers empty_dot_dims;
  DotGeneral(lhs, rhs, empty_dot_dims);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // Expected: 2 flops for each of the 10 * 5 * 5 * 30 (7500) output elements.
  EXPECT_EQ(cost.flop_count(), 2 * 10 * 30 * 5 * 5);
  EXPECT_EQ(cost.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 5);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * 5 * 30);
  EXPECT_EQ(cost.output_bytes_accessed(*root),
            sizeof(float) * 5 * 5 * 10 * 30);
}
266 
// Maps the fixture's add_and_exp_ computation over a 10-element vector.
TEST_F(HloCostAnalysisTest, Map) {
  XlaBuilder builder("map");
  const XlaOp in =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in");
  Map(&builder, {in}, add_and_exp_, {0});

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // The add accounts for 10 flops, the exp for 10 transcendental ops.
  EXPECT_EQ(cost.flop_count(), 10);
  EXPECT_EQ(cost.transcendental_count(), 10);
  EXPECT_EQ(cost.bytes_accessed(), 80);

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 10);
}
287 
// Valid-padded convolution of a [1x1x10x20] input with a [1x1x3x3] kernel.
TEST_F(HloCostAnalysisTest, Convolution) {
  XlaBuilder builder("convolution");
  const Shape input_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, /*x_dim=*/20});
  const XlaOp input = Parameter(&builder, 0, input_shape, "input");
  const Shape kernel_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3});
  const XlaOp kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // The output is [1x1x8x18]; every output element takes (3x3) FMAs at
  // 2 flops per FMA.
  EXPECT_EQ(cost.flop_count(), 8 * 18 * 2 * 3 * 3);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * 3 * 3);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 8 * 18);
}
321 
// Same-padded convolution of a 3x3 input with a 3x3 kernel at stride 1.
TEST_F(HloCostAnalysisTest, ConvolutionSame) {
  // Input, kernel and output extents plus the strides.
  constexpr int iw = 3, ih = 3;
  constexpr int kw = 3, kh = 3;
  constexpr int ow = iw, oh = ih;
  constexpr int sx = 1, sy = 1;

  XlaBuilder builder("convolution_same");
  const Shape input_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/ih, /*x_dim=*/iw});
  const XlaOp input = Parameter(&builder, 0, input_shape, "input");
  const Shape kernel_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kh, /*x_dim=*/kw});
  const XlaOp kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  Conv(input, kernel, {sx, sy}, Padding::kSame);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // The output is [1x1x3x3]. The expected total is spelled out as one FMA
  // count per output element, at 2 flops per FMA.
  // NOTE: This formula only works for the hard-coded dimensions for now.
  EXPECT_EQ(cost.flop_count(), 2 * (4 + 6 + 4 + 6 + 9 + 6 + 4 + 6 + 4));

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (iw * ih + kw * kh + ow * oh));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * iw * ih);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * kw * kh);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * ow * oh);
}
364 
// 1-D convolution whose input, kernel, window stride and lhs dilation are all
// on the order of kLarge.
TEST_F(HloCostAnalysisTest, ConvolutionExtreme) {
  constexpr int64_t kLarge = 512 * 1024;

  XlaBuilder builder("convolution");
  const Shape large_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kLarge});
  const XlaOp input = Parameter(&builder, 0, large_shape, "input");
  const XlaOp kernel = Parameter(&builder, 1, large_shape, "kernel");
  const auto conv_dims = XlaBuilder::CreateDefaultConvDimensionNumbers(1);
  ConvGeneralDilated(input, kernel, {kLarge - 1}, {{0, 0}}, {kLarge}, {1},
                     conv_dims);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  EXPECT_EQ(cost.flop_count(), 2 * kLarge);
}
387 
// 1-D convolution of a single-element input with a kLarge-wide kernel, padded
// by kLarge - 1 on both sides.
TEST_F(HloCostAnalysisTest, ConvolutionExtreme2) {
  constexpr int64_t kLarge = 512 * 1024;

  XlaBuilder builder("convolution");
  const Shape input_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/1});
  const XlaOp input = Parameter(&builder, 0, input_shape, "input");
  const Shape kernel_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kLarge});
  const XlaOp kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  const auto conv_dims = XlaBuilder::CreateDefaultConvDimensionNumbers(1);
  ConvGeneralDilated(input, kernel, {1}, {{kLarge - 1, kLarge - 1}}, {1}, {1},
                     conv_dims);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  EXPECT_EQ(cost.flop_count(), 2 * kLarge);
}
410 
// Valid-padded convolution of a [1x120x10x20] input with a [120x1x3x3] kernel
// using feature_group_count = 120.
TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
  XlaBuilder builder("convolution");
  const Shape input_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10, /*x_dim=*/20});
  const XlaOp input = Parameter(&builder, 0, input_shape, "input");
  const Shape kernel_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3});
  const XlaOp kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // The output is [1x120x8x18]; every output element takes (3x3) FMAs at
  // 2 flops per FMA.
  EXPECT_EQ(cost.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0),
            sizeof(float) * 120 * 10 * 20);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1),
            sizeof(float) * 120 * 3 * 3);
  EXPECT_EQ(cost.output_bytes_accessed(*root),
            sizeof(float) * 120 * 8 * 18);
}
447 
// Sum-reduces dimension 1 of a [10x20] input with the fixture's add_.
TEST_F(HloCostAnalysisTest, Reduce) {
  XlaBuilder builder("reduce");
  const XlaOp input =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  const XlaOp init = ConstantR0<float>(&builder, 0.0f);
  Reduce(input, init, add_, {1});

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // The number of reduction operations is the input size minus the output
  // size.
  EXPECT_EQ(cost.flop_count(), 10 * 20 - 10);

  // Input, scalar init value and output.
  EXPECT_EQ(cost.bytes_accessed(), sizeof(float) * (10 * 20 + 1 + 10));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * 1);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 10);
}
471 
// Sum-reduces [4x5] windows, stride [4x5], over a [10x20] input.
TEST_F(HloCostAnalysisTest, ReduceWindow) {
  XlaBuilder builder("reduce_window");
  const XlaOp input =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  const XlaOp init = ConstantR0<float>(&builder, 0);
  ReduceWindow(input, init, add_, {4, 5}, {4, 5}, Padding::kValid);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // Every one of the [2x4] output elements comes from reducing [4x5] elements.
  EXPECT_EQ(cost.flop_count(), 2 * 4 * (4 * 5 - 1));

  // Input, scalar init value and output.
  EXPECT_EQ(cost.bytes_accessed(), sizeof(float) * (10 * 20 + 1 + 2 * 4));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * 1);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 2 * 4);
}
495 
// Variadic ReduceWindow over two [10x20] inputs with a pairwise-min reducer.
TEST_F(HloCostAnalysisTest, ReduceWindowVariadic) {
  XlaBuilder builder("reduce_window_variadic");

  // Reducer: (x0, x1, y0, y1) => (min(x0, y0), min(x1, y1)). It is built
  // first, from the same builder that afterwards builds the main graph.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const XlaOp x0 = Parameter(&builder, 0, scalar_shape, "x0");
  const XlaOp x1 = Parameter(&builder, 1, scalar_shape, "x1");
  const XlaOp y0 = Parameter(&builder, 2, scalar_shape, "y0");
  const XlaOp y1 = Parameter(&builder, 3, scalar_shape, "y1");
  absl::InlinedVector<XlaOp, 2> mins = {Min(x0, y0), Min(x1, y1)};
  Tuple(&builder, mins);
  TF_ASSERT_OK_AND_ASSIGN(auto reducer, builder.Build());

  // Main graph: reduce [4x5] windows, stride [4x5], over both inputs at once.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const XlaOp input1 = Parameter(&builder, 0, input_shape, "input1");
  const XlaOp input2 = Parameter(&builder, 1, input_shape, "input2");
  const XlaOp init = ConstantR0<float>(&builder, 0);
  ReduceWindow({input1, input2}, {init, init}, reducer, {4, 5}, {4, 5},
               Padding::kValid);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // Every one of the [2x4] output elements comes from reducing [4x5] elements,
  // once per input.
  EXPECT_EQ(cost.flop_count(), 2 * 4 * 2 * (4 * 5 - 1));

  EXPECT_EQ(cost.bytes_accessed(), sizeof(float) * (10 * 20 * 2 + 2 * 3));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * 10 * 20);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 4);
}
530 
// SelectAndScatter over a [10x20] operand with [4x5] windows, selecting with
// the fixture's gt_ and scattering a [2x4] source with add_.
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
  XlaBuilder builder("select_and_scatter");
  const XlaOp operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
  const XlaOp source =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
  const XlaOp init = ConstantR0<float>(&builder, 0);
  SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source, init,
                   add_);

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  // Every one of the [2x4] source elements finds its destination by reducing
  // [4x5] elements and then runs the scatter computation.
  EXPECT_EQ(cost.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));

  // Operand, source, scalar init value and output.
  EXPECT_EQ(cost.bytes_accessed(),
            sizeof(float) * (10 * 20 + 2 * 4 + 1 + 10 * 20));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 1), sizeof(float) * 2 * 4);
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 2), sizeof(float) * 1);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 10 * 20);
}
559 
// Broadcasts a scalar constant with broadcast sizes {10, 7}; expects zero
// flops.
TEST_F(HloCostAnalysisTest, Broadcast) {
  XlaBuilder builder("broadcast");
  const XlaOp scalar = ConstantR0<float>(&builder, 42);
  Broadcast(scalar, {10, 7});

  // Convert to HLO and run the cost analysis starting at the root instruction.
  std::unique_ptr<HloModule> module = BuildHloGraph(&builder);
  HloInstruction* const root = module->entry_computation()->root_instruction();
  HloCostAnalysis cost(ShapeSize);
  ASSERT_IS_OK(root->Accept(&cost));

  EXPECT_EQ(cost.flop_count(), 0);

  // Scalar input plus the [10x7] output.
  EXPECT_EQ(cost.bytes_accessed(), sizeof(float) * (1 + 10 * 7));

  // Per-operand and output byte counts of the root instruction.
  EXPECT_EQ(cost.operand_bytes_accessed(*root, 0), sizeof(float) * 1);
  EXPECT_EQ(cost.output_bytes_accessed(*root), sizeof(float) * 10 * 7);
}
575 
576 // Calculates the computation cost of a graph with more than one HLO node.
TEST_F(HloCostAnalysisTest,FullyConnectedForward)577 TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
578   XlaBuilder builder("fully_connected_forward");
579   auto input =
580       Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
581   auto weight =
582       Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
583   auto bias = Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {20}), "bias");
584   // sigmoid(input * weight + bias)
585   Map(&builder, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});
586 
587   // Run HLO cost analysis.
588   auto hlo_module = BuildHloGraph(&builder);
589   HloCostAnalysis analysis(ShapeSize);
590   ASSERT_IS_OK(
591       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
592 
593   // 1000 FMAs from matrix multiplication, 200 flops from bias addition,
594   // 600 flops from sigmoid, and 200 transcendental ops from sigmoid.
595   EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
596   EXPECT_EQ(analysis.transcendental_count(), 200);
597 }
598 
// A 1x1 convolution over [64x64x1x1] operands and a [64x64] x [64x64] matmul
// are expected to report the same flop count.
TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
  // Convolution that looks like a matmul.
  HloCostAnalysis conv_analysis(ShapeSize);
  {
    XlaBuilder conv_builder("conv_looking_matmul");
    const Shape conv_shape = ShapeUtil::MakeShape(F32, {64, 64, 1, 1});
    const XlaOp input = Parameter(&conv_builder, 0, conv_shape, "input");
    const XlaOp weights = Parameter(&conv_builder, 1, conv_shape, "weights");
    Conv(input, weights, {1, 1}, Padding::kSame);

    std::unique_ptr<HloModule> module = BuildHloGraph(&conv_builder);
    HloInstruction* const root =
        module->entry_computation()->root_instruction();
    ASSERT_IS_OK(root->Accept(&conv_analysis));
  }

  // The plain matmul.
  HloCostAnalysis matmul_analysis(ShapeSize);
  {
    XlaBuilder matmul_builder("matmul");
    const Shape matrix_shape = ShapeUtil::MakeShape(F32, {64, 64});
    const XlaOp input = Parameter(&matmul_builder, 0, matrix_shape, "input");
    const XlaOp weights =
        Parameter(&matmul_builder, 1, matrix_shape, "weights");
    Dot(input, weights);

    std::unique_ptr<HloModule> module = BuildHloGraph(&matmul_builder);
    HloInstruction* const root =
        module->entry_computation()->root_instruction();
    ASSERT_IS_OK(root->Accept(&matmul_analysis));
  }

  EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}
628 
629 using FusionCostAnalysis = HloTestBase;
630 
// Test for b/234935631.
// A DynamicUpdateSlice inside a loop fusion needs to respect operand-output
// aliasing: the fusion is expected to report the same byte counts as a
// standalone dynamic-update-slice.
TEST_F(FusionCostAnalysis, LoopFusionDynUpdateSlice) {
  // Loop fusion that dynamic-slices parameter 2 and writes the slice into
  // parameter 0 with a dynamic-update-slice.
  const char* fusion_hlo_text = R"(
  HloModule module

  _.1 {
    tmp_0 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(0)
    tmp_1 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(2)
    tmp_2 = s32[]{:T(128)} parameter(1)
    tmp_3 = s32[]{:T(128)} constant(0)
    tmp_4 = bf16[1,32,256,1152]{3,2,1,0:T(8,128)(2,1)S(3)} dynamic-slice(tmp_1, tmp_2, tmp_3, tmp_3, tmp_3), dynamic_slice_sizes={1,32,256,1152}
    tmp_11 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} dynamic-update-slice(tmp_0, tmp_4, tmp_2, tmp_3, tmp_3, tmp_3)
    ROOT tmp_20 = (bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)}) tuple(tmp_11)
  }

  ENTRY _ {
    _0 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(0)
    _1 = s32[]{:T(128)} parameter(1)
    _4 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(2)
    ROOT _ = (bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)}) fusion(_0, _1, _4), kind=kLoop, calls=_.1
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto fusion_module,
                          ParseAndReturnVerifiedModule(fusion_hlo_text));
  HloInstruction* fusion =
      fusion_module->entry_computation()->root_instruction();
  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  // Reference module: a bare dynamic-update-slice.
  const char* dus_hlo_text = R"(
  HloModule module

  ENTRY _ {
    _0 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(0)
    _1 = s32[]{:T(128)} parameter(1)
    _2 = bf16[1,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(2)
    ROOT _ = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} dynamic-update-slice(_0, _2, _1, _1, _1, _1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto dus_module,
                          ParseAndReturnVerifiedModule(dus_hlo_text));
  HloInstruction* dus = dus_module->entry_computation()->root_instruction();
  HloCostAnalysis dus_analysis(ShapeSize);
  ASSERT_IS_OK(dus->Accept(&dus_analysis));

  // The updated buffer (fusion operand 0) is expected to count as zero bytes
  // accessed.
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0), 0);

  // The totals and the base-buffer operand must agree with the reference.
  EXPECT_EQ(fusion_analysis.bytes_accessed(), dus_analysis.bytes_accessed());
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
            dus_analysis.operand_bytes_accessed(*dus, 0));

  // Operand positions differ between the two modules: the fusion takes the
  // s32 index at operand 1 and the sliced array at operand 2, the reference
  // takes the update at operand 1 and the index at operand 2.
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
            dus_analysis.operand_bytes_accessed(*dus, 2));
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
            dus_analysis.operand_bytes_accessed(*dus, 1));

  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
            dus_analysis.output_bytes_accessed(*dus));
}
689 
TEST_F(FusionCostAnalysis,LoopFusion)690 TEST_F(FusionCostAnalysis, LoopFusion) {
691   // Do this 4 times with different per-second rates to test the computation of
692   // bottleneck time on fusion nodes.
693   for (int i = 0; i < 4; ++i) {
694     Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
695 
696     // Fuse all instructions in complicated expression:
697     //
698     //   add = Add(C1, C2)
699     //   clamp = Clamp(C2, add, add)
700     //   exp = Exp(add)
701     //   mul = Mul(exp, C3)
702     //   sub = Sub(mul, clamp)
703     //   tuple = Tuple({sub, sub, mul, C1})
704     HloComputation::Builder builder(TestName());
705     auto c1 = builder.AddInstruction(
706         HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
707             /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
708     auto c2 = builder.AddInstruction(
709         HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
710             /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
711     auto c3 = builder.AddInstruction(
712         HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
713             /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
714     auto add = builder.AddInstruction(
715         HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
716     auto clamp = builder.AddInstruction(
717         HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
718     auto exp = builder.AddInstruction(
719         HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
720     auto mul = builder.AddInstruction(
721         HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
722     auto sub = builder.AddInstruction(
723         HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
724     auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});
725 
726     auto module = CreateNewVerifiedModule();
727     auto* computation = module->AddEntryComputation(builder.Build());
728     auto* fusion = computation->CreateFusionInstruction(
729         {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
730 
731     // The time given these rates at i == 0 is exactly even among the properties
732     // at 1.0 seconds. For other values, one of the rates is slower so that it
733     // becomes the bottleneck.
734     HloCostAnalysis::Options options{ShapeSize};
735     options.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0));
736     options.set_transcendentals_per_second(4 * (i == 2 ? 1 / 4.0 : 1.0));
737     options.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0));
738     HloCostAnalysis fusion_analysis(options);
739     ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
740 
741     EXPECT_EQ(fusion_analysis.flop_count(), 16);
742     EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
743     constexpr int64_t bytes_accessed = sizeof(float) * 4 * 2 * 2;
744     static_assert(bytes_accessed == 64, "");
745     EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);
746 
747     EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
748               sizeof(float) * 2 * 2);
749     EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
750               sizeof(float) * 2 * 2);
751     EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
752               sizeof(float) * 2 * 2);
753     EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
754               sizeof(float) * 2 * 2);
755 
756     EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
757   }
758 }
759 
TEST_F(FusionCostAnalysis,LoopFusionTupleOutput)760 TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) {
761   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
762 
763   // Same as above but the fusion outputs a tuple.
764   HloComputation::Builder builder(TestName());
765   auto c1 = builder.AddInstruction(
766       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
767           /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
768   auto c2 = builder.AddInstruction(
769       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
770           /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
771   auto c3 = builder.AddInstruction(
772       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
773           /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
774   auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
775   auto add = builder.AddInstruction(
776       HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
777   auto clamp = builder.AddInstruction(
778       HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
779   auto exp = builder.AddInstruction(
780       HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
781   auto mul = builder.AddInstruction(
782       HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
783   auto sub = builder.AddInstruction(
784       HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
785   auto tuple2 = builder.AddInstruction(
786       HloInstruction::CreateTuple({sub, sub, mul, tuple1}));
787 
788   auto module = CreateNewVerifiedModule();
789   auto* computation = module->AddEntryComputation(builder.Build());
790   auto* fusion = computation->CreateFusionInstruction(
791       {tuple2, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
792 
793   HloCostAnalysis fusion_analysis(ShapeSize);
794   ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
795 
796   EXPECT_EQ(fusion_analysis.flop_count(), 16);
797   EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
798   EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion),
799             sizeof(float) * (5 + 5) * 2 * 2);
800 
801   EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
802             sizeof(float) * 2 * 2 * 2);
803   EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
804             sizeof(float) * 2 * 2);
805   EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
806             sizeof(float) * 2 * 2);
807   EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 3),
808             sizeof(float) * 2 * 2);
809   EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
810             sizeof(float) * 5 * 2 * 2);
811   EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {0}),
812             sizeof(float) * 2 * 2);
813   EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {1}),
814             sizeof(float) * 2 * 2);
815   EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {2}),
816             sizeof(float) * 2 * 2);
817   EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3}),
818             sizeof(float) * 2 * 2 * 2);
819   EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3, 0}),
820             sizeof(float) * 2 * 2);
821   EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3, 1}),
822             sizeof(float) * 2 * 2);
823 }
824 
TEST_F(FusionCostAnalysis, NoLayout) {
  // Instructions inside a fused op may carry a shape without a layout; the
  // broadcast below is built with such a shape.
  Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  Shape shape_without_layout = shape_with_layout;
  shape_without_layout.clear_layout();

  HloComputation::Builder builder(TestName());
  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));

  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_with_layout, HloOpcode::kAdd, c1, broadcast));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {add, broadcast}, HloInstruction::FusionKind::kLoop);

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  // 2 * 3 * 4 * 5 == 120 elements in the result.
  EXPECT_EQ(fusion_analysis.flop_count(), 120);
  EXPECT_EQ(fusion_analysis.transcendental_count(), 0);

  // Byte sizes of the rank-4 array and of the 3-element vector.
  constexpr auto kTensorBytes = sizeof(float) * 2 * 3 * 4 * 5;
  constexpr auto kVectorBytes = sizeof(float) * 3;

  // Total = rank-4 operand + vector operand + rank-4 output.
  EXPECT_EQ(fusion_analysis.bytes_accessed(),
            sizeof(float) * (2 * 3 * 4 * 5 + 3 + 2 * 3 * 4 * 5));

  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0), kTensorBytes);
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1), kVectorBytes);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion), kTensorBytes);
}
863 
TEST_F(FusionCostAnalysis, NonTupleWithTupleParamBytesAccessed) {
  // A loop fusion that takes one tuple operand (two f32[3,2] arrays) and
  // produces a single f32[3,2] array.
  const absl::string_view kHloText = R"(
HloModule module, is_scheduled=true

fused_computation {
  param = (f32[3,2]{1,0}, f32[3,2]{1,0}) parameter(0)
  gte0 = f32[3,2]{1,0} get-tuple-element(param), index=0
  gte1 = f32[3,2]{1,0} get-tuple-element(param), index=1
  ROOT add = f32[3,2]{1,0} add(gte0, gte1)
}

ENTRY entry {
  param0 = f32[3,2]{1,0} parameter(0)
  param1 = f32[3,2]{1,0} parameter(1)
  tuple = (f32[3,2]{1,0}, f32[3,2]{1,0}) tuple(param0, param1)
  ROOT fusion = f32[3,2]{1,0} fusion(tuple), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  HloInstruction* fusion = module->entry_computation()->root_instruction();

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  // Size in bytes of one f32[3,2] array.
  constexpr auto kArrayBytes = sizeof(float) * 3 * 2;

  // The tuple operand accounts for two arrays, the output for one, and the
  // total for all three.
  EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion), kArrayBytes * 3);
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
            kArrayBytes * 2);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion), kArrayBytes);
}
896 
TEST_F(FusionCostAnalysis, TupleBytesAccessed) {
  // A loop fusion whose single operand and whose result are both tuples of
  // two f32[2,2] arrays.
  const absl::string_view kHloText = R"(
HloModule module, is_scheduled=true

fused_computation {
  param = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  gte0 = f32[2,2]{1,0} get-tuple-element(param), index=0
  gte1 = f32[2,2]{1,0} get-tuple-element(param), index=1
  add = f32[2,2]{1,0} add(gte0, gte1)
  mul = f32[2,2]{1,0} multiply(gte0, gte1)
  ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul)
}

ENTRY entry {
  param0 = f32[2,2]{1,0} parameter(0)
  param1 = f32[2,2]{1,0} parameter(1)
  tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(param0, param1)
  ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(tuple), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  HloInstruction* fusion = module->entry_computation()->root_instruction();

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  // Size in bytes of one f32[2,2] array.
  constexpr auto kArrayBytes = sizeof(float) * 2 * 2;

  // Two arrays on the operand side, two on the output side, four in total.
  EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion), kArrayBytes * 4);
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
            kArrayBytes * 2);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion), kArrayBytes * 2);
}
931 
TEST_F(FusionCostAnalysis, InfeedOutfeed) {
  // An infeed producing one f32[2,3] array, doubled and sent to an outfeed.
  const absl::string_view kHloText = R"(
HloModule module, is_scheduled=true

ENTRY entry {
  after-all = token[] after-all()
  infeed = ((f32[2,3]{1,0}), token[]) infeed(after-all)
  gte0 = (f32[2,3]{1,0}) get-tuple-element(infeed), index=0
  gte1 = f32[2,3]{1,0} get-tuple-element(gte0), index=0
  add = f32[2,3]{1,0} add(gte1, gte1)
  tuple = (f32[2,3]{1,0}) tuple(add)
  tok = token[] get-tuple-element(infeed), index=1
  ROOT outfeed = token[] outfeed(tuple, tok)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  HloComputation* entry = module->entry_computation();
  HloInstruction* infeed = entry->GetInstructionWithName("infeed");
  HloInstruction* outfeed = entry->GetInstructionWithName("outfeed");

  // A single analysis object visits both instructions.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(infeed->Accept(&analysis));
  ASSERT_IS_OK(outfeed->Accept(&analysis));

  // Size in bytes of the f32[2,3] payload.
  constexpr auto kPayloadBytes = sizeof(float) * 2 * 3;

  // Infeed: all bytes are on the output side.
  EXPECT_EQ(analysis.bytes_accessed(*infeed), kPayloadBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*infeed, 0), 0);
  EXPECT_EQ(analysis.output_bytes_accessed(*infeed), kPayloadBytes);

  // Outfeed: all bytes are on the operand side.
  EXPECT_EQ(analysis.bytes_accessed(*outfeed), kPayloadBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*outfeed, 0), kPayloadBytes);
  EXPECT_EQ(analysis.output_bytes_accessed(*outfeed), 0);
}
968 
TEST_F(FusionCostAnalysis, AllReduceTupleBytesAccessed) {
  // A variadic all-reduce over two f32[2,2] operands with a tuple result.
  const absl::string_view kHloText = R"(
HloModule module, is_scheduled=true

sum {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[2,2]{1,0} parameter(0)
  param1 = f32[2,2]{1,0} parameter(1)
  ROOT all-reduce = (f32[2,2]{1,0}, f32[2,2]{1,0}) all-reduce(param0, param1), replica_groups={{0,1}}, to_apply=sum
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  HloInstruction* all_reduce = module->entry_computation()->root_instruction();

  HloCostAnalysis all_reduce_analysis(ShapeSize);
  ASSERT_IS_OK(all_reduce->Accept(&all_reduce_analysis));

  // Size in bytes of one f32[2,2] array.
  constexpr auto kArrayBytes = sizeof(float) * 2 * 2;

  // One array per operand, two arrays in the tuple output, four in total.
  EXPECT_EQ(all_reduce_analysis.bytes_accessed(*all_reduce), kArrayBytes * 4);
  EXPECT_EQ(all_reduce_analysis.operand_bytes_accessed(*all_reduce, 0),
            kArrayBytes);
  EXPECT_EQ(all_reduce_analysis.operand_bytes_accessed(*all_reduce, 1),
            kArrayBytes);
  EXPECT_EQ(all_reduce_analysis.output_bytes_accessed(*all_reduce),
            kArrayBytes * 2);
}
1002 
TEST_F(HloCostAnalysisTest, TupleCost) {
  // Builds a tuple of two parameters and checks that the analysis charges
  // only for the tuple's pointer table, not for the elements themselves.
  HloCostAnalysis analysis(ShapeSize);

  XlaBuilder builder("tuple");
  const Shape x_shape = ShapeUtil::MakeShape(F32, {123});
  const Shape y_shape = ShapeUtil::MakeShape(F32, {42});
  auto x = Parameter(&builder, 0, x_shape, "x");
  auto y = Parameter(&builder, 1, y_shape, "y");
  Tuple(&builder, {x, y});
  auto hlo_module = BuildHloGraph(&builder);

  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  ASSERT_IS_OK(root->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 0);
  EXPECT_EQ(analysis.transcendental_count(), 0);
  // Two pointers, one per tuple element.
  EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), 0);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), 0);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), kPointerSize * 2);
}
1024 
1025 using DomainCostAnalysis = HloTestBase;
TEST_F(DomainCostAnalysis,DomainCost)1026 TEST_F(DomainCostAnalysis, DomainCost) {
1027   HloCostAnalysis analysis(ShapeSize);
1028 
1029   HloComputation::Builder builder("domain");
1030   auto x = builder.AddInstruction(HloInstruction::CreateParameter(
1031       0, ShapeUtil::MakeShape(F32, {123}), "x"));
1032   auto y = builder.AddInstruction(
1033       HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
1034   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
1035   auto domain = builder.AddInstruction(
1036       HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
1037 
1038   auto hlo_module = CreateNewVerifiedModule();
1039   hlo_module->AddEntryComputation(builder.Build());
1040 
1041   EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
1042   ASSERT_IS_OK(domain->Accept(&analysis));
1043 
1044   EXPECT_EQ(analysis.flop_count(*domain), 0);
1045   EXPECT_EQ(analysis.transcendental_count(*domain), 0);
1046   EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
1047 }
1048 
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
  // Checks the flop count of a 2D convolution that uses both base (lhs) and
  // window (rhs) dilation.
  XlaBuilder builder("BaseDilatedConvolution");

  const Shape input_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                 /*x_dim=*/20});
  const Shape kernel_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto kernel = Parameter(&builder, 1, kernel_shape, "kernel");

  ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1},
                     /*padding=*/{{1, 1}, {1, 1}},
                     /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
                     XlaBuilder::CreateDefaultConvDimensionNumbers(2));

  // Analyze the root (the convolution) of the built graph.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 1472);
}
1075 
TEST_F(HloCostAnalysisTest, Slice) {
  // Slices one element out of a two-element F32 vector and checks the bytes
  // attributed to the slice instruction.
  XlaBuilder builder("slice");
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {2});
  auto x = Parameter(&builder, 0, vector_shape, "x");
  Slice(x, {0}, {1}, {1});
  auto hlo_module = BuildHloGraph(&builder);

  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // One float on the operand side plus one float on the output side.
  EXPECT_EQ(analysis.bytes_accessed(), 8);

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float));
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float));
}
1094 
TEST_F(HloCostAnalysisTest, DynamicSlice) {
  // Takes a one-element dynamic slice (start index given by an S32 scalar
  // constant) of a two-element F32 vector and checks the bytes attributed to
  // the dynamic-slice instruction.
  XlaBuilder builder("dynamic-slice");
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {2});
  auto x = Parameter(&builder, 0, vector_shape, "x");
  DynamicSlice(x, absl::Span<const XlaOp>({ConstantR0<int32_t>(&builder, 1)}),
               {1});
  auto hlo_module = BuildHloGraph(&builder);

  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // 8 bytes for the float read and written, plus 4 bytes for the index.
  EXPECT_EQ(analysis.bytes_accessed(), 8 + 4);

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float));
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32_t));
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float));
}
1115 
TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
  // Writes a one-element F32 update into a two-element F32 vector at an
  // index given by an S32 scalar constant, and checks the bytes attributed
  // to the dynamic-update-slice instruction.
  XlaBuilder builder("dynamic-update-slice");
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {2});
  auto x = Parameter(&builder, 0, vector_shape, "x");
  DynamicUpdateSlice(
      x, ConstantR1<float>(&builder, {1.0}),
      absl::Span<const XlaOp>({ConstantR0<int32_t>(&builder, 1)}));
  auto hlo_module = BuildHloGraph(&builder);

  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // 8 bytes for the update read and the output written, plus 4 bytes for the
  // index.
  EXPECT_EQ(analysis.bytes_accessed(), 8 + 4);

  // The array being updated (operand 0) is expected to contribute nothing.
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), 0);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float));
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(int32_t));
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float));
}
1139 
TEST_F(HloCostAnalysisTest, Gather) {
  // Gathers two rows (slice size {1, 3}) out of a 3x3 S32 operand using a
  // two-element S32 index vector, and checks the bytes attributed to the
  // gather instruction.
  XlaBuilder builder("gather");
  Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});

  auto operand = Parameter(&builder, 0, operand_shape, "operand");
  auto indices = Parameter(&builder, 1, indices_shape, "indices");
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_collapsed_slice_dims(0);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  Gather(operand, indices, dim_numbers, {1, 3});

  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Both the operand and the gathered result hold S32 elements, so the
  // expected sizes are expressed in terms of int32_t (not float).
  constexpr auto kGatheredBytes = sizeof(int32_t) * 2 * 3;
  constexpr auto kIndicesBytes = sizeof(int32_t) * 2;
  static_assert(kGatheredBytes + kIndicesBytes + kGatheredBytes == 56, "");

  // Total = operand bytes + index bytes + output bytes.
  EXPECT_EQ(analysis.bytes_accessed(), 56);

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), kGatheredBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), kIndicesBytes);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), kGatheredBytes);
}
1169 
TEST_F(HloCostAnalysisTest, Scatter) {
  // Scatters a 2x3 F32 update into a 3x3 F32 operand using a two-element S32
  // index vector and the fixture's `add_` computation, then checks the bytes
  // attributed to the scatter instruction.
  XlaBuilder builder("scatter");
  Shape operand_shape = ShapeUtil::MakeShape(F32, {3, 3});
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
  Shape values_shape = ShapeUtil::MakeShape(F32, {2, 3});

  auto operand = Parameter(&builder, 0, operand_shape, "operand");
  auto indices = Parameter(&builder, 1, indices_shape, "indices");
  auto values = Parameter(&builder, 2, values_shape, "values");

  ScatterDimensionNumbers dim_numbers;
  dim_numbers.set_index_vector_dim(1);
  dim_numbers.add_update_window_dims(1);
  dim_numbers.add_inserted_window_dims(0);
  dim_numbers.add_scatter_dims_to_operand_dims(0);
  Scatter(operand, indices, values, add_, dim_numbers);

  auto hlo_module = BuildHloGraph(&builder);

  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Byte sizes of a 2x3 float window and of the two S32 indices.
  constexpr auto kWindowBytes = sizeof(float) * 2 * 3;
  constexpr auto kIndicesBytes = sizeof(int32_t) * 2;

  // Total = 2 index elements + 3 windows (operand, updates, output), all
  // 4-byte elements.
  EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 3 * (2 * 3)));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), kWindowBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), kIndicesBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), kWindowBytes);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), kWindowBytes);
}
1202 
TEST_F(HloCostAnalysisTest, MultioutputScatter) {
  // Variadic scatter: an F32 and an S32 3x3 operand are updated together
  // from matching 2x3 update arrays using one shared S32 index vector.
  XlaBuilder builder("scatter");
  Shape operand0_shape = ShapeUtil::MakeShape(F32, {3, 3});
  Shape operand1_shape = ShapeUtil::MakeShape(S32, {3, 3});
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
  Shape values0_shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape values1_shape = ShapeUtil::MakeShape(S32, {2, 3});

  auto operand0 = Parameter(&builder, 0, operand0_shape, "operand0");
  auto operand1 = Parameter(&builder, 1, operand1_shape, "operand1");
  auto indices = Parameter(&builder, 2, indices_shape, "indices");
  auto values0 = Parameter(&builder, 3, values0_shape, "values0");
  auto values1 = Parameter(&builder, 4, values1_shape, "values1");

  ScatterDimensionNumbers dim_numbers;
  dim_numbers.set_index_vector_dim(1);
  dim_numbers.add_update_window_dims(1);
  dim_numbers.add_inserted_window_dims(0);
  dim_numbers.add_scatter_dims_to_operand_dims(0);

  // Update computation: element-wise add on each (F32, S32) pair, returned
  // as a tuple. It is built with its own builder, separate from the one
  // used for the scatter graph.
  auto tuple_add = [] {
    XlaBuilder add_builder("add");
    const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
    const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
    auto x0 = Parameter(&add_builder, 0, f32_scalar, "x0");
    auto x1 = Parameter(&add_builder, 1, s32_scalar, "x1");
    auto y0 = Parameter(&add_builder, 2, f32_scalar, "y0");
    auto y1 = Parameter(&add_builder, 3, s32_scalar, "y1");
    Tuple(&add_builder, {Add(x0, y0), Add(x1, y1)});
    auto computation_status = add_builder.Build();
    TF_CHECK_OK(computation_status.status());
    return std::move(computation_status).ValueOrDie();
  }();
  Scatter({operand0, operand1}, indices, {values0, values1}, tuple_add,
          dim_numbers);

  auto hlo_module = BuildHloGraph(&builder);

  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Total = 2 index elements + twice the three windows of the single-output
  // case, all 4-byte elements.
  EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * 3 * (2 * 3)));

  // Operands are ordered: the two scatter operands, the indices, then the
  // two update arrays.
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 2 * 3);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32_t) * 2 * 3);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(int32_t) * 2);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 3), sizeof(float) * 2 * 3);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 4), sizeof(int32_t) * 2 * 3);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), 2 * sizeof(float) * 2 * 3);
}
1252 
1253 }  // namespace
1254 }  // namespace xla
1255