1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/client/client.h"
22 #include "tensorflow/compiler/xla/client/client_library.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/padding.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/local_service.h"
29 #include "tensorflow/compiler/xla/service/service.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/logging.h"
36
37 namespace xla {
38 namespace {
39
// Pointer size, in bytes, that ShapeSize() below passes to
// ShapeUtil::ByteSizeOf.
constexpr int64_t kPointerSize = 8;
41
ShapeSize(const Shape & shape)42 int64_t ShapeSize(const Shape& shape) {
43 return ShapeUtil::ByteSizeOf(shape, kPointerSize);
44 }
45
46 // This test suite tests the HLO cost analysis by first building a computation
47 // using the client computation builder and running the HloCostAnalysis that
48 // returns the number of floating point and transcendental operations in the
49 // graph. We test both individual HLO operations as well as a mixed graph.
50 class HloCostAnalysisTest : public ::testing::Test {
51 protected:
HloCostAnalysisTest()52 HloCostAnalysisTest()
53 : client_(ClientLibrary::LocalClientOrDie()),
54 // Accessing service instance is required for the unit tests to enable
55 // whitebox accesses to the user computation built from the client,
56 // as shown in the BuildHloGraph functions below.
57 service_(static_cast<Service*>(ClientLibrary::GetXlaService(
58 static_cast<LocalClient*>(client_)->platform()))) {
59 // Create a computation for a unary user function: x => exp(x + 0.5)
60 {
61 XlaBuilder builder("add_and_exp");
62 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
63 auto half = ConstantR0<float>(&builder, 0.5);
64 Exp(Add(x, half));
65 auto computation_status = builder.Build();
66 TF_CHECK_OK(computation_status.status());
67 add_and_exp_ = std::move(computation_status).value();
68 }
69
70 // Create a computation for a binary user function: (x, y) => x + y
71 {
72 XlaBuilder builder("add");
73 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
74 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
75 Add(x, y);
76 auto computation_status = builder.Build();
77 TF_CHECK_OK(computation_status.status());
78 add_ = std::move(computation_status).value();
79 }
80
81 // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
82 {
83 XlaBuilder builder("sigmoid");
84 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
85 auto one = ConstantR0<float>(&builder, 1.0);
86 Div(one, Add(one, Exp(Neg(x))));
87 auto computation_status = builder.Build();
88 TF_CHECK_OK(computation_status.status());
89 sigmoid_ = std::move(computation_status).value();
90 }
91
92 // Create a computation for a binary max function: (x, y) => max (x, y)
93 {
94 XlaBuilder builder("max");
95 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
96 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
97 Max(x, y);
98 auto computation_status = builder.Build();
99 TF_CHECK_OK(computation_status.status());
100 max_ = std::move(computation_status).value();
101 }
102
103 // Create a computation for a binary GT function: (x, y) => x > y
104 {
105 XlaBuilder builder("gt");
106 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
107 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
108 Gt(x, y);
109 auto computation_status = builder.Build();
110 TF_CHECK_OK(computation_status.status());
111 gt_ = std::move(computation_status).value();
112 }
113 }
114
115 // Build HLO graph from the given builder and return the HLO module.
BuildHloGraph(XlaBuilder * builder)116 std::unique_ptr<HloModule> BuildHloGraph(XlaBuilder* builder) {
117 auto computation_status = builder->Build();
118 TF_CHECK_OK(computation_status.status());
119 auto computation = std::move(computation_status).value();
120 auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
121 DebugOptions())
122 .value();
123 return HloModule::CreateFromProto(computation.proto(), config).value();
124 }
125
126 Client* client_;
127 Service* service_;
128
129 // User computations used for higher order operations (e.g., Map, Reduce).
130 XlaComputation add_;
131 XlaComputation add_and_exp_;
132 XlaComputation sigmoid_;
133 XlaComputation max_;
134 XlaComputation gt_;
135 };
136
// Makes HloTestBase available as a fixture under the name
// HloCostAnalysisHloTest.
// NOTE(review): no test in this part of the file uses this alias; presumably
// tests further down do — confirm.
using HloCostAnalysisHloTest = HloTestBase;
138
TEST_F(HloCostAnalysisTest, MatrixMultiply) {
  // Plain matrix product: [kRows x kInner] * [kInner x kCols].
  constexpr int kRows = 10;
  constexpr int kInner = 5;
  constexpr int kCols = 30;

  XlaBuilder builder("matrix_multiply");
  auto lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {kRows, kInner}), "lhs");
  auto rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {kInner, kCols}), "rhs");
  Dot(lhs, rhs);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // 10 * 30 * 5 = 1500 FMAs, each counted as 2 flops.
  EXPECT_EQ(analysis.flop_count(), 2 * kRows * kCols * kInner);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the [kRows x kCols] output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kRows * kInner + kInner * kCols + kRows * kCols));

  // Per-operand and output byte counts of the dot itself.
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kRows * kInner);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kInner * kCols);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kRows * kCols);
}
165
TEST_F(HloCostAnalysisTest, DotGeneral) {
  // lhs is [kRows x kInner x kInner], rhs is [kInner x kInner x kCols]; both
  // kInner-sized dimensions of each operand are contracted.
  constexpr int kRows = 10;
  constexpr int kInner = 5;
  constexpr int kCols = 30;

  XlaBuilder builder("matrix_multiply");
  auto lhs = Parameter(
      &builder, 0, ShapeUtil::MakeShape(F32, {kRows, kInner, kInner}), "lhs");
  auto rhs = Parameter(
      &builder, 1, ShapeUtil::MakeShape(F32, {kInner, kInner, kCols}), "rhs");
  DotDimensionNumbers dnums;
  for (const int dim : {1, 2}) {
    dnums.add_lhs_contracting_dimensions(dim);
  }
  for (const int dim : {0, 1}) {
    dnums.add_rhs_contracting_dimensions(dim);
  }
  DotGeneral(lhs, rhs, dnums);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // 10 * 30 * 5 * 5 = 7500 FMAs, each counted as 2 flops.
  EXPECT_EQ(analysis.flop_count(), 2 * kRows * kCols * kInner * kInner);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the [kRows x kCols] output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kRows * kInner * kInner + kInner * kInner * kCols +
                             kRows * kCols));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kRows * kInner * kInner);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kInner * kInner * kCols);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kRows * kCols);
}
201
TEST_F(HloCostAnalysisTest, DotGeneral2) {
  // Same operand shapes as DotGeneral, but only one kInner-sized dimension per
  // operand is contracted; the other one is a batch dimension.
  constexpr int kRows = 10;
  constexpr int kInner = 5;
  constexpr int kCols = 30;

  XlaBuilder builder("matrix_multiply");
  auto lhs = Parameter(
      &builder, 0, ShapeUtil::MakeShape(F32, {kRows, kInner, kInner}), "lhs");
  auto rhs = Parameter(
      &builder, 1, ShapeUtil::MakeShape(F32, {kInner, kInner, kCols}), "rhs");
  DotDimensionNumbers dnums;
  // Contracting pair: lhs dim 1 with rhs dim 0.
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  // Batch pair: lhs dim 2 with rhs dim 1.
  dnums.add_lhs_batch_dimensions(2);
  dnums.add_rhs_batch_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // 5 * 10 * 30 output elements with 5 FMAs each (7500 FMAs), 2 flops per FMA.
  EXPECT_EQ(analysis.flop_count(), 2 * kRows * kCols * kInner * kInner);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the [kInner x kRows x kCols]
  // output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kRows * kInner * kInner + kInner * kInner * kCols +
                             kInner * kRows * kCols));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kRows * kInner * kInner);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kInner * kInner * kCols);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kInner * kRows * kCols);
}
237
TEST_F(HloCostAnalysisTest, DotGeneral3) {
  // DotGeneral with default (empty) dimension numbers on a [kRows x kInner]
  // and a [kInner x kCols] operand.
  constexpr int kRows = 10;
  constexpr int kInner = 5;
  constexpr int kCols = 30;

  XlaBuilder builder("matrix_multiply");
  auto lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {kRows, kInner}), "lhs");
  auto rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {kInner, kCols}), "rhs");
  DotDimensionNumbers dnums;
  DotGeneral(lhs, rhs, dnums);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // The expected output holds 5 * 5 * 10 * 30 = 7500 elements (see the output
  // size check below); 2 flops are expected per element.
  EXPECT_EQ(analysis.flop_count(), 2 * kRows * kCols * kInner * kInner);

  EXPECT_EQ(analysis.transcendental_count(), 0);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kRows * kInner + kInner * kCols +
                             kInner * kInner * kRows * kCols));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kRows * kInner);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kInner * kCols);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kInner * kInner * kRows * kCols);
}
266
TEST_F(HloCostAnalysisTest, Map) {
  // Maps x => exp(x + 0.5) over a vector of kElements floats.
  constexpr int kElements = 10;

  XlaBuilder builder("map");
  auto input =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {kElements}), "in");
  Map(&builder, {input}, add_and_exp_, {0});

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // One add (a flop) and one exp (a transcendental op) per element.
  EXPECT_EQ(analysis.flop_count(), 10);
  EXPECT_EQ(analysis.transcendental_count(), 10);
  EXPECT_EQ(analysis.bytes_accessed(), 80);

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kElements);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * kElements);
}
287
TEST_F(HloCostAnalysisTest, Convolution) {
  // A 10x20 input convolved with a 3x3 kernel, stride 1 and valid padding,
  // yields an 8x18 output.
  constexpr int kInputY = 10;
  constexpr int kInputX = 20;
  constexpr int kKernelY = 3;
  constexpr int kKernelX = 3;
  constexpr int kOutputY = 8;
  constexpr int kOutputX = 18;

  XlaBuilder builder("convolution");
  const Shape input_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kInputY,
                                 /*x_dim=*/kInputX});
  const Shape kernel_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kKernelY,
                                 /*x_dim=*/kKernelX});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Every element of the [1x1x8x18] output needs (3x3) FMAs, and one FMA
  // counts as 2 flops.
  EXPECT_EQ(analysis.flop_count(),
            kOutputY * kOutputX * 2 * kKernelY * kKernelX);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kInputY * kInputX + kKernelY * kKernelX +
                             kOutputY * kOutputX));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kInputY * kInputX);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kKernelY * kKernelX);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kOutputY * kOutputX);
}
321
TEST_F(HloCostAnalysisTest, ConvolutionSame) {
  // 3x3 input, 3x3 kernel, unit strides and "same" padding, so the output has
  // the same spatial size as the input.
  constexpr int kInputW = 3;
  constexpr int kInputH = 3;
  constexpr int kKernelW = 3;
  constexpr int kKernelH = 3;
  constexpr int kOutputW = kInputW;
  constexpr int kOutputH = kInputH;
  constexpr int kStrideX = 1;
  constexpr int kStrideY = 1;

  XlaBuilder builder("convolution_same");
  const Shape input_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kInputH,
                                 /*x_dim=*/kInputW});
  const Shape kernel_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kKernelH,
                                 /*x_dim=*/kKernelW});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  Conv(input, kernel, {kStrideX, kStrideY}, Padding::kSame);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // The output is [1x1x3x3]. The expected value sums one FMA count per output
  // element (4, 6 or 9 depending on the position) and counts 2 flops per FMA.
  // NOTE: This formula only works for the hard-coded dimensions for now.
  EXPECT_EQ(analysis.flop_count(), 2 * (4 + 6 + 4 + 6 + 9 + 6 + 4 + 6 + 4));

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kInputW * kInputH + kKernelW * kKernelH +
                             kOutputW * kOutputH));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kInputW * kInputH);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kKernelW * kKernelH);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kOutputW * kOutputH);
}
364
TEST_F(HloCostAnalysisTest, ConvolutionExtreme) {
  // 1-D convolution where the input and the kernel both have kLarge spatial
  // elements, combined with a huge window stride and a huge lhs dilation.
  constexpr int64_t kLarge = 512 * 1024;

  XlaBuilder builder("convolution");
  const Shape large_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kLarge});
  auto input = Parameter(&builder, 0, large_shape, "input");
  auto kernel = Parameter(&builder, 1, large_shape, "kernel");
  const auto conv_dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(1);
  ConvGeneralDilated(input, kernel, /*window_strides=*/{kLarge - 1},
                     /*padding=*/{{0, 0}}, /*lhs_dilation=*/{kLarge},
                     /*rhs_dilation=*/{1}, conv_dnums);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 2 * kLarge);
}
387
TEST_F(HloCostAnalysisTest, ConvolutionExtreme2) {
  // 1-D convolution of a single-element input with a kLarge-element kernel,
  // padded by kLarge - 1 on both sides.
  constexpr int64_t kLarge = 512 * 1024;

  XlaBuilder builder("convolution");
  const Shape input_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/1});
  const Shape kernel_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/kLarge});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  const auto conv_dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(1);
  ConvGeneralDilated(input, kernel, /*window_strides=*/{1},
                     /*padding=*/{{kLarge - 1, kLarge - 1}},
                     /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, conv_dnums);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 2 * kLarge);
}
410
TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
  // Same spatial setup as the Convolution test (10x20 input, 3x3 kernel, 8x18
  // output), but with kGroups features and a feature group count of kGroups.
  constexpr int kGroups = 120;
  constexpr int kInputY = 10;
  constexpr int kInputX = 20;
  constexpr int kKernelY = 3;
  constexpr int kKernelX = 3;
  constexpr int kOutputY = 8;
  constexpr int kOutputX = 18;

  XlaBuilder builder("convolution");
  const Shape input_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/1, /*z_dim=*/kGroups, /*y_dim=*/kInputY,
            /*x_dim=*/kInputX});
  const Shape kernel_shape = ShapeUtil::MakeShape(
      F32, {/*p_dim=*/kGroups, /*z_dim=*/1, /*y_dim=*/kKernelY,
            /*x_dim=*/kKernelX});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto kernel = Parameter(&builder, 1, kernel_shape, "kernel");
  Conv(input, kernel, {1, 1}, Padding::kValid,
       /*feature_group_count=*/kGroups);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Every element of the [1x120x8x18] output needs (3x3) FMAs, and one FMA
  // counts as 2 flops.
  EXPECT_EQ(analysis.flop_count(),
            kGroups * kOutputY * kOutputX * 2 * kKernelY * kKernelX);

  // Total bytes accessed: both inputs plus the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kGroups * kInputY * kInputX +
                             kGroups * kKernelY * kKernelX +
                             kGroups * kOutputY * kOutputX));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kGroups * kInputY * kInputX);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kGroups * kKernelY * kKernelX);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kGroups * kOutputY * kOutputX);
}
447
TEST_F(HloCostAnalysisTest, Reduce) {
  // Sums a [kRows x kCols] input along dimension 1, giving kRows outputs.
  constexpr int kRows = 10;
  constexpr int kCols = 20;

  XlaBuilder builder("reduce");
  auto input = Parameter(&builder, 0,
                         ShapeUtil::MakeShape(F32, {kRows, kCols}), "input");
  auto init = ConstantR0<float>(&builder, 0.0f);
  Reduce(input, init, add_, {1});

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // The number of reduction operations is the input size minus the output
  // size.
  EXPECT_EQ(analysis.flop_count(), kRows * kCols - kRows);

  // Input, the scalar init value, and the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kRows * kCols + 1 + kRows));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kRows * kCols);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 1);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * kRows);
}
471
TEST_F(HloCostAnalysisTest, ReduceWindow) {
  // A [10x20] input reduced with [4x5] windows and [4x5] strides under valid
  // padding, giving a [2x4] output.
  constexpr int kRows = 10;
  constexpr int kCols = 20;
  constexpr int kWindowRows = 4;
  constexpr int kWindowCols = 5;
  constexpr int kOutRows = 2;
  constexpr int kOutCols = 4;

  XlaBuilder builder("reduce_window");
  auto input = Parameter(&builder, 0,
                         ShapeUtil::MakeShape(F32, {kRows, kCols}), "input");
  auto init = ConstantR0<float>(&builder, 0);
  ReduceWindow(input, init, add_, {kWindowRows, kWindowCols},
               {kWindowRows, kWindowCols}, Padding::kValid);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Each of the [2x4] output elements comes from reducing [4x5] elements.
  EXPECT_EQ(analysis.flop_count(),
            kOutRows * kOutCols * (kWindowRows * kWindowCols - 1));

  // Input, the scalar init value, and the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) * (kRows * kCols + 1 + kOutRows * kOutCols));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kRows * kCols);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 1);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kOutRows * kOutCols);
}
495
TEST_F(HloCostAnalysisTest, ReduceWindowVariadic) {
  XlaBuilder builder("reduce_window_variadic");

  // First build the reducer: (x0, x1, y0, y1) => (min(x0, y0), min(x1, y1)).
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto x0 = Parameter(&builder, 0, scalar_f32, "x0");
  auto x1 = Parameter(&builder, 1, scalar_f32, "x1");
  auto y0 = Parameter(&builder, 2, scalar_f32, "y0");
  auto y1 = Parameter(&builder, 3, scalar_f32, "y1");
  absl::InlinedVector<XlaOp, 2> reduced_pair = {Min(x0, y0), Min(x1, y1)};
  Tuple(&builder, reduced_pair);
  TF_ASSERT_OK_AND_ASSIGN(auto reducer, builder.Build());

  // The same builder is then reused for the two-input reduce-window itself:
  // [4x5] windows with [4x5] strides over two [10x20] inputs.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 20});
  auto input1 = Parameter(&builder, 0, input_shape, "input1");
  auto input2 = Parameter(&builder, 1, input_shape, "input2");
  auto init = ConstantR0<float>(&builder, 0);
  ReduceWindow({input1, input2}, {init, init}, reducer, {4, 5}, {4, 5},
               Padding::kValid);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Each of the [2x4] output elements comes from reducing [4x5] elements, and
  // this happens for both inputs.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * 2 * (4 * 5 - 1));

  EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (10 * 20 * 2 + 2 * 3));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 10 * 20);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float) * 10 * 20);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 4);
}
530
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
  // A [10x20] operand with [4x5] windows and [4x5] strides, and a [2x4]
  // source; gt_ is the select function and add_ the scatter function.
  constexpr int kRows = 10;
  constexpr int kCols = 20;
  constexpr int kSourceRows = 2;
  constexpr int kSourceCols = 4;

  XlaBuilder builder("select_and_scatter");
  auto operand = Parameter(&builder, 0,
                           ShapeUtil::MakeShape(F32, {kRows, kCols}), "input");
  auto source =
      Parameter(&builder, 1,
                ShapeUtil::MakeShape(F32, {kSourceRows, kSourceCols}),
                "source");
  auto init = ConstantR0<float>(&builder, 0);
  SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source, init,
                   add_);

  // Convert to HLO and let the cost analysis visit the graph from its root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // Each of the [2x4] source elements finds its destination by reducing
  // [4x5] elements, followed by the scatter computation.
  EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));

  // Operand, source, the scalar init value, and the output.
  EXPECT_EQ(analysis.bytes_accessed(),
            sizeof(float) *
                (kRows * kCols + kSourceRows * kSourceCols + 1 + kRows * kCols));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0),
            sizeof(float) * kRows * kCols);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1),
            sizeof(float) * kSourceRows * kSourceCols);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(float) * 1);
  EXPECT_EQ(analysis.output_bytes_accessed(*root),
            sizeof(float) * kRows * kCols);
}
559
TEST_F(HloCostAnalysisTest, Broadcast) {
  // Broadcasts a scalar constant to a [10x7] array.
  XlaBuilder b("broadcast");
  auto scalar = ConstantR0<float>(&b, 42);
  Broadcast(scalar, {10, 7});

  auto hlo_module = BuildHloGraph(&b);
  HloInstruction* const root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // No flops are expected for a broadcast.
  EXPECT_EQ(analysis.flop_count(), 0);

  // The scalar operand plus the [10x7] output.
  EXPECT_EQ(analysis.bytes_accessed(), sizeof(float) * (1 + 10 * 7));

  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 1);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 10 * 7);
}
575
576 // Calculates the computation cost of a graph with more than one HLO node.
TEST_F(HloCostAnalysisTest,FullyConnectedForward)577 TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
578 XlaBuilder builder("fully_connected_forward");
579 auto input =
580 Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
581 auto weight =
582 Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
583 auto bias = Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {20}), "bias");
584 // sigmoid(input * weight + bias)
585 Map(&builder, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});
586
587 // Run HLO cost analysis.
588 auto hlo_module = BuildHloGraph(&builder);
589 HloCostAnalysis analysis(ShapeSize);
590 ASSERT_IS_OK(
591 hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
592
593 // 1000 FMAs from matrix multiplication, 200 flops from bias addition,
594 // 600 flops from sigmoid, and 200 transcendental ops from sigmoid.
595 EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
596 EXPECT_EQ(analysis.transcendental_count(), 200);
597 }
598
TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
  // A convolution over [64x64x1x1] operands is expected to cost the same
  // number of flops as a [64x64] * [64x64] matrix product.
  HloCostAnalysis conv_analysis(ShapeSize);
  {
    XlaBuilder builder("conv_looking_matmul");
    const Shape operand_shape = ShapeUtil::MakeShape(F32, {64, 64, 1, 1});
    auto lhs = Parameter(&builder, 0, operand_shape, "input");
    auto rhs = Parameter(&builder, 1, operand_shape, "weights");
    Conv(lhs, rhs, {1, 1}, Padding::kSame);
    auto hlo_module = BuildHloGraph(&builder);
    HloInstruction* const root =
        hlo_module->entry_computation()->root_instruction();
    ASSERT_IS_OK(root->Accept(&conv_analysis));
  }

  HloCostAnalysis matmul_analysis(ShapeSize);
  {
    XlaBuilder builder("matmul");
    const Shape operand_shape = ShapeUtil::MakeShape(F32, {64, 64});
    auto lhs = Parameter(&builder, 0, operand_shape, "input");
    auto rhs = Parameter(&builder, 1, operand_shape, "weights");
    Dot(lhs, rhs);
    auto hlo_module = BuildHloGraph(&builder);
    HloInstruction* const root =
        hlo_module->entry_computation()->root_instruction();
    ASSERT_IS_OK(root->Accept(&matmul_analysis));
  }

  EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}
628
// Fixture name for the fusion cost tests that follow; a plain alias of
// HloTestBase.
using FusionCostAnalysis = HloTestBase;
630
TEST_F(FusionCostAnalysis, LoopFusionDynUpdateSlice) {
  // Test for b/234935631.
  // A DynamicUpdateSlice inside a loop fusion has to respect operand-output
  // aliasing. The fusion's costs are compared against those of a standalone
  // dynamic-update-slice.
  const char* const fusion_hlo = R"(
  HloModule module

  _.1 {
    tmp_0 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(0)
    tmp_1 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(2)
    tmp_2 = s32[]{:T(128)} parameter(1)
    tmp_3 = s32[]{:T(128)} constant(0)
    tmp_4 = bf16[1,32,256,1152]{3,2,1,0:T(8,128)(2,1)S(3)} dynamic-slice(tmp_1, tmp_2, tmp_3, tmp_3, tmp_3), dynamic_slice_sizes={1,32,256,1152}
    tmp_11 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} dynamic-update-slice(tmp_0, tmp_4, tmp_2, tmp_3, tmp_3, tmp_3)
    ROOT tmp_20 = (bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)}) tuple(tmp_11)
  }

  ENTRY _ {
    _0 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(0)
    _1 = s32[]{:T(128)} parameter(1)
    _4 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(2)
    ROOT _ = (bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)}) fusion(_0, _1, _4), kind=kLoop, calls=_.1
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto fusion_module,
                          ParseAndReturnVerifiedModule(fusion_hlo));
  HloCostAnalysis fusion_analysis(ShapeSize);

  HloInstruction* fusion =
      fusion_module->entry_computation()->root_instruction();
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  // Reference module: the same update expressed as a bare
  // dynamic-update-slice.
  const char* const dus_hlo = R"(
  HloModule module

  ENTRY _ {
    _0 = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(0)
    _1 = s32[]{:T(128)} parameter(1)
    _2 = bf16[1,32,256,1152]{3,2,1,0:T(8,128)(2,1)} parameter(2)
    ROOT _ = bf16[50,32,256,1152]{3,2,1,0:T(8,128)(2,1)} dynamic-update-slice(_0, _2, _1, _1, _1, _1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto dus_module,
                          ParseAndReturnVerifiedModule(dus_hlo));
  HloCostAnalysis dus_analysis(ShapeSize);
  auto dus = dus_module->entry_computation()->root_instruction();
  ASSERT_IS_OK(dus->Accept(&dus_analysis));

  // No bytes of the fusion's first operand are expected to be accessed.
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0), 0);

  // Totals and per-operand numbers have to match the standalone
  // dynamic-update-slice. Fusion operands 1 and 2 are compared against DUS
  // operands 2 and 1 respectively, because the operand order differs between
  // the two instructions.
  EXPECT_EQ(fusion_analysis.bytes_accessed(), dus_analysis.bytes_accessed());
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
            dus_analysis.operand_bytes_accessed(*dus, 0));
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
            dus_analysis.operand_bytes_accessed(*dus, 2));
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2),
            dus_analysis.operand_bytes_accessed(*dus, 1));
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
            dus_analysis.output_bytes_accessed(*dus));
}
689
TEST_F(FusionCostAnalysis, LoopFusion) {
  // Run the same fusion 4 times with different per-second rates, to exercise
  // the bottleneck time computation on fusion nodes.
  for (int i = 0; i < 4; ++i) {
    Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});

    // All instructions of this expression get fused:
    //
    //   add = Add(C1, C2)
    //   clamp = Clamp(C2, add, add)
    //   exp = Exp(add)
    //   mul = Mul(exp, C3)
    //   sub = Sub(mul, clamp)
    //   tuple = Tuple({sub, sub, mul, C1})
    HloComputation::Builder builder(TestName());

    // Adds a 2x2 F32 constant holding a linspace over [from, to].
    const auto add_linspace_constant = [&builder](float from, float to) {
      return builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
              from, to, /*rows=*/2, /*cols=*/2)));
    };
    auto c1 = add_linspace_constant(/*from=*/0.0f, /*to=*/1.0f);
    auto c2 = add_linspace_constant(/*from=*/1.0f, /*to=*/2.0f);
    auto c3 = add_linspace_constant(/*from=*/2.0f, /*to=*/3.0f);

    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
    auto clamp = builder.AddInstruction(
        HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
    auto mul = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
    auto sub = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
    // NOTE(review): this tuple is created but never handed to the builder in
    // this test; it is deliberately left exactly as it was.
    auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});

    auto module = CreateNewVerifiedModule();
    auto* computation = module->AddEntryComputation(builder.Build());
    auto* fusion = computation->CreateFusionInstruction(
        {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

    // At i == 0 all rates are balanced so that every property takes exactly
    // 1.0 seconds. For each other value of i one rate is lowered, which makes
    // that property the bottleneck.
    HloCostAnalysis::Options options{ShapeSize};
    options.set_flops_per_second(i == 1 ? 16 / 2.0 : 16.0);
    options.set_transcendentals_per_second(i == 2 ? 4 / 4.0 : 4.0);
    options.set_bytes_per_second(i == 3 ? 64 / 8.0 : 64.0);
    HloCostAnalysis fusion_analysis(options);
    ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

    EXPECT_EQ(fusion_analysis.flop_count(), 16);
    EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
    constexpr int64_t bytes_accessed = sizeof(float) * 4 * 2 * 2;
    static_assert(bytes_accessed == 64, "");
    EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);

    // Each of the three operands and the output is a 2x2 F32 array.
    constexpr auto kArrayBytes = sizeof(float) * 2 * 2;
    EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0), kArrayBytes);
    EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1), kArrayBytes);
    EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 2), kArrayBytes);
    EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion), kArrayBytes);

    // The slowed-down rates divide by 2, 4 and 8, hence 1 << i seconds.
    EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
  }
}
759
// Checks the cost analysis of a loop fusion whose root is a tuple (one of
// whose elements is itself a tuple): flop and transcendental counts, plus the
// bytes accessed per operand and per output shape index.
TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  // Size in bytes of a single 2x2 f32 array.
  constexpr auto kMatrixBytes = sizeof(float) * 2 * 2;

  // Same computation as the preceding test, except that the fusion outputs a
  // tuple.
  HloComputation::Builder builder(TestName());
  // Adds a 2x2 f32 constant produced by CreateR2F32Linspace to the builder.
  auto add_linspace_constant = [&builder](float from, float to) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
            from, to, /*rows=*/2, /*cols=*/2)));
  };
  HloInstruction* const0 = add_linspace_constant(0.0f, 1.0f);
  HloInstruction* const1 = add_linspace_constant(1.0f, 2.0f);
  HloInstruction* const2 = add_linspace_constant(2.0f, 3.0f);
  HloInstruction* constant_pair =
      builder.AddInstruction(HloInstruction::CreateTuple({const0, const1}));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kAdd, const0, const1));
  HloInstruction* clamped = builder.AddInstruction(
      HloInstruction::CreateTernary(matrix_shape, HloOpcode::kClamp, const1,
                                    sum, sum));
  HloInstruction* exponential = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kExp, sum));
  HloInstruction* product = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kMultiply, exponential, const2));
  HloInstruction* difference =
      builder.AddInstruction(HloInstruction::CreateBinary(
          matrix_shape, HloOpcode::kSubtract, product, clamped));
  HloInstruction* root_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple(
          {difference, difference, product, constant_pair}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {root_tuple, difference, product, exponential, clamped, sum},
      HloInstruction::FusionKind::kLoop);

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  EXPECT_EQ(fusion_analysis.flop_count(), 16);
  EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
  // Five matrices on the operand side plus five on the output side.
  EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion), kMatrixBytes * (5 + 5));

  // Operand 0 is expected to account for two matrices; the remaining operands
  // for one matrix each.
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
            kMatrixBytes * 2);
  for (int64_t operand = 1; operand <= 3; ++operand) {
    EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, operand),
              kMatrixBytes)
        << "operand " << operand;
  }

  // Whole output, then each top-level tuple element, then the leaves of the
  // nested tuple at index 3.
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion), kMatrixBytes * 5);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {0}), kMatrixBytes);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {1}), kMatrixBytes);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {2}), kMatrixBytes);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3}),
            kMatrixBytes * 2);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3, 0}),
            kMatrixBytes);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion, {3, 1}),
            kMatrixBytes);
}
824
// Checks that cost analysis of a fusion succeeds, and reports the expected
// counts, when an instruction inside the fusion has a shape without a layout.
TEST_F(FusionCostAnalysis, NoLayout) {
  // Number of elements in the f32[2,3,4,5] tensors used below.
  constexpr auto kTensorElements = 2 * 3 * 4 * 5;

  const Shape laid_out_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  // Instructions within a fused op may have no layout, so strip it from a
  // copy of the shape.
  Shape layout_free_shape = laid_out_shape;
  layout_free_shape.clear_layout();

  HloComputation::Builder builder(TestName());
  HloInstruction* tensor_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
  HloInstruction* vector_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));

  // The broadcast is the instruction that carries the layout-less shape.
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          layout_free_shape, vector_constant, {1}));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      laid_out_shape, HloOpcode::kAdd, tensor_constant, broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {sum, broadcast}, HloInstruction::FusionKind::kLoop);

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

  // One flop per output element of the add.
  EXPECT_EQ(fusion_analysis.flop_count(), kTensorElements);
  EXPECT_EQ(fusion_analysis.transcendental_count(), 0);

  // Full tensor operand + three-element vector operand + full tensor output.
  EXPECT_EQ(fusion_analysis.bytes_accessed(),
            sizeof(float) * (kTensorElements + 3 + kTensorElements));

  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0),
            sizeof(float) * kTensorElements);
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 1),
            sizeof(float) * 3);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion),
            sizeof(float) * kTensorElements);
}
863
// Checks bytes accessed for a fusion that takes a single tuple operand and
// produces a non-tuple (array) result.
TEST_F(FusionCostAnalysis, NonTupleWithTupleParamBytesAccessed) {
  const absl::string_view hlo_text = R"(
      HloModule module, is_scheduled=true

      fused_computation {
        param = (f32[3,2]{1,0}, f32[3,2]{1,0}) parameter(0)
        gte0 = f32[3,2]{1,0} get-tuple-element(param), index=0
        gte1 = f32[3,2]{1,0} get-tuple-element(param), index=1
        ROOT add = f32[3,2]{1,0} add(gte0, gte1)
      }

      ENTRY entry {
        param0 = f32[3,2]{1,0} parameter(0)
        param1 = f32[3,2]{1,0} parameter(1)
        tuple = (f32[3,2]{1,0}, f32[3,2]{1,0}) tuple(param0, param1)
        ROOT fusion = f32[3,2]{1,0} fusion(tuple), kind=kLoop, calls=fused_computation
      }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The fusion under test is the root of the entry computation.
  HloInstruction* fusion_root =
      parsed_module->entry_computation()->root_instruction();

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis));

  // Size in bytes of one f32[3,2] array.
  constexpr auto kArrayBytes = sizeof(float) * 3 * 2;

  // Two arrays read through the tuple operand plus one array written.
  EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion_root), kArrayBytes * 3);
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion_root, 0),
            kArrayBytes * 2);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion_root), kArrayBytes);
}
896
// Checks bytes accessed for a fusion whose single operand and whose result
// are both two-element tuples of f32[2,2] arrays.
TEST_F(FusionCostAnalysis, TupleBytesAccessed) {
  const absl::string_view hlo_text = R"(
      HloModule module, is_scheduled=true

      fused_computation {
        param = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
        gte0 = f32[2,2]{1,0} get-tuple-element(param), index=0
        gte1 = f32[2,2]{1,0} get-tuple-element(param), index=1
        add = f32[2,2]{1,0} add(gte0, gte1)
        mul = f32[2,2]{1,0} multiply(gte0, gte1)
        ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul)
      }

      ENTRY entry {
        param0 = f32[2,2]{1,0} parameter(0)
        param1 = f32[2,2]{1,0} parameter(1)
        tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(param0, param1)
        ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(tuple), kind=kLoop, calls=fused_computation
      }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The fusion under test is the root of the entry computation.
  HloInstruction* fusion_root =
      parsed_module->entry_computation()->root_instruction();

  HloCostAnalysis fusion_analysis(ShapeSize);
  ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis));

  // Size in bytes of one f32[2,2] array.
  constexpr auto kArrayBytes = sizeof(float) * 2 * 2;

  // Two arrays on the operand side and two on the output side.
  EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion_root), kArrayBytes * 4);
  EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion_root, 0),
            kArrayBytes * 2);
  EXPECT_EQ(fusion_analysis.output_bytes_accessed(*fusion_root),
            kArrayBytes * 2);
}
931
// Checks bytes accessed for infeed and outfeed instructions: the infeed is
// expected to report only output bytes and the outfeed only operand bytes.
TEST_F(FusionCostAnalysis, InfeedOutfeed) {
  const absl::string_view hlo_text = R"(
      HloModule module, is_scheduled=true

      ENTRY entry {
        after-all = token[] after-all()
        infeed = ((f32[2,3]{1,0}), token[]) infeed(after-all)
        gte0 = (f32[2,3]{1,0}) get-tuple-element(infeed), index=0
        gte1 = f32[2,3]{1,0} get-tuple-element(gte0), index=0
        add = f32[2,3]{1,0} add(gte1, gte1)
        tuple = (f32[2,3]{1,0}) tuple(add)
        tok = token[] get-tuple-element(infeed), index=1
        ROOT outfeed = token[] outfeed(tuple, tok)
      }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloComputation* entry = parsed_module->entry_computation();
  HloInstruction* infeed_instr = entry->GetInstructionWithName("infeed");
  HloInstruction* outfeed_instr = entry->GetInstructionWithName("outfeed");

  // A single analysis object visits both instructions.
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(infeed_instr->Accept(&analysis));
  ASSERT_IS_OK(outfeed_instr->Accept(&analysis));

  // Size in bytes of the transferred f32[2,3] array.
  constexpr auto kPayloadBytes = sizeof(float) * 2 * 3;

  EXPECT_EQ(analysis.bytes_accessed(*infeed_instr), kPayloadBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*infeed_instr, 0), 0);
  EXPECT_EQ(analysis.output_bytes_accessed(*infeed_instr), kPayloadBytes);

  EXPECT_EQ(analysis.bytes_accessed(*outfeed_instr), kPayloadBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*outfeed_instr, 0), kPayloadBytes);
  EXPECT_EQ(analysis.output_bytes_accessed(*outfeed_instr), 0);
}
968
// Checks bytes accessed for an all-reduce with two array operands and a
// tuple-shaped result.
TEST_F(FusionCostAnalysis, AllReduceTupleBytesAccessed) {
  const absl::string_view hlo_text = R"(
      HloModule module, is_scheduled=true

      sum {
        lhs = f32[] parameter(0)
        rhs = f32[] parameter(1)
        ROOT add = f32[] add(lhs, rhs)
      }

      ENTRY entry {
        param0 = f32[2,2]{1,0} parameter(0)
        param1 = f32[2,2]{1,0} parameter(1)
        ROOT all-reduce = (f32[2,2]{1,0}, f32[2,2]{1,0}) all-reduce(param0, param1), replica_groups={{0,1}}, to_apply=sum
      }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The all-reduce under test is the root of the entry computation.
  HloInstruction* all_reduce_root =
      parsed_module->entry_computation()->root_instruction();

  HloCostAnalysis all_reduce_analysis(ShapeSize);
  ASSERT_IS_OK(all_reduce_root->Accept(&all_reduce_analysis));

  // Size in bytes of one f32[2,2] array.
  constexpr auto kArrayBytes = sizeof(float) * 2 * 2;

  // Two arrays read (one per operand) and two arrays written.
  EXPECT_EQ(all_reduce_analysis.bytes_accessed(*all_reduce_root),
            kArrayBytes * 4);
  EXPECT_EQ(all_reduce_analysis.operand_bytes_accessed(*all_reduce_root, 0),
            kArrayBytes);
  EXPECT_EQ(all_reduce_analysis.operand_bytes_accessed(*all_reduce_root, 1),
            kArrayBytes);
  EXPECT_EQ(all_reduce_analysis.output_bytes_accessed(*all_reduce_root),
            kArrayBytes * 2);
}
1002
// Checks that a two-element tuple instruction costs no flops and that its
// only bytes accessed are the two pointers written to the output.
TEST_F(HloCostAnalysisTest, TupleCost) {
  HloCostAnalysis analysis(ShapeSize);

  XlaBuilder builder("tuple");
  auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
  auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
  Tuple(&builder, {lhs, rhs});
  auto hlo_module = BuildHloGraph(&builder);

  // The tuple is the root of the entry computation.
  HloInstruction* tuple_root =
      hlo_module->entry_computation()->root_instruction();
  ASSERT_IS_OK(tuple_root->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 0);
  EXPECT_EQ(analysis.transcendental_count(), 0);
  // One pointer per tuple element.
  EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);

  // The operands are expected to contribute no bytes; only the output does.
  EXPECT_EQ(analysis.operand_bytes_accessed(*tuple_root, 0), 0);
  EXPECT_EQ(analysis.operand_bytes_accessed(*tuple_root, 1), 0);
  EXPECT_EQ(analysis.output_bytes_accessed(*tuple_root), kPointerSize * 2);
}
1024
1025 using DomainCostAnalysis = HloTestBase;
TEST_F(DomainCostAnalysis,DomainCost)1026 TEST_F(DomainCostAnalysis, DomainCost) {
1027 HloCostAnalysis analysis(ShapeSize);
1028
1029 HloComputation::Builder builder("domain");
1030 auto x = builder.AddInstruction(HloInstruction::CreateParameter(
1031 0, ShapeUtil::MakeShape(F32, {123}), "x"));
1032 auto y = builder.AddInstruction(
1033 HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
1034 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
1035 auto domain = builder.AddInstruction(
1036 HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
1037
1038 auto hlo_module = CreateNewVerifiedModule();
1039 hlo_module->AddEntryComputation(builder.Build());
1040
1041 EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
1042 ASSERT_IS_OK(domain->Accept(&analysis));
1043
1044 EXPECT_EQ(analysis.flop_count(*domain), 0);
1045 EXPECT_EQ(analysis.transcendental_count(*domain), 0);
1046 EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
1047 }
1048
// Checks the flop count reported for a 2D convolution that uses both base
// (lhs) and window (rhs) dilation.
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
  XlaBuilder builder("BaseDilatedConvolution");

  const Shape input_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
                                 /*x_dim=*/20});
  const Shape kernel_shape =
      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
                                 /*x_dim=*/3});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto kernel = Parameter(&builder, 1, kernel_shape, "kernel");

  ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1},
                     /*padding=*/{{1, 1}, {1, 1}},
                     /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
                     XlaBuilder::CreateDefaultConvDimensionNumbers(2));

  // Lower the builder to HLO and analyze from the convolution root.
  auto hlo_module = BuildHloGraph(&builder);
  HloInstruction* conv_root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(conv_root->Accept(&analysis));

  EXPECT_EQ(analysis.flop_count(), 1472);
}
1075
// Checks bytes accessed for a static slice that extracts one element from a
// two-element f32 vector.
TEST_F(HloCostAnalysisTest, Slice) {
  XlaBuilder builder("slice");
  auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  Slice(input, {0}, {1}, {1});
  auto hlo_module = BuildHloGraph(&builder);

  // Run the cost analysis starting at the slice, which is the entry root.
  HloInstruction* slice_root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(slice_root->Accept(&analysis));

  // One float read from the operand plus one float written to the output.
  EXPECT_EQ(analysis.bytes_accessed(), 8);

  EXPECT_EQ(analysis.operand_bytes_accessed(*slice_root, 0), sizeof(float));
  EXPECT_EQ(analysis.output_bytes_accessed(*slice_root), sizeof(float));
}
1094
// Checks bytes accessed for a dynamic slice that extracts one element from a
// two-element f32 vector at an s32 start index.
TEST_F(HloCostAnalysisTest, DynamicSlice) {
  XlaBuilder builder("dynamic-slice");
  auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  DynamicSlice(input,
               absl::Span<const XlaOp>({ConstantR0<int32_t>(&builder, 1)}),
               {1});
  auto hlo_module = BuildHloGraph(&builder);

  // Run the cost analysis starting at the dynamic slice, which is the entry
  // root.
  HloInstruction* slice_root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(slice_root->Accept(&analysis));

  // One float read + one float written (8 bytes), plus the 4-byte index.
  EXPECT_EQ(analysis.bytes_accessed(), 8 + 4);

  EXPECT_EQ(analysis.operand_bytes_accessed(*slice_root, 0), sizeof(float));
  EXPECT_EQ(analysis.operand_bytes_accessed(*slice_root, 1), sizeof(int32_t));
  EXPECT_EQ(analysis.output_bytes_accessed(*slice_root), sizeof(float));
}
1115
// Checks bytes accessed for a dynamic-update-slice that writes a one-element
// f32 update into a two-element f32 vector at an s32 start index.
TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
  XlaBuilder builder("dynamic-update-slice");
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
  // Create the update and the start index in separate statements. Both calls
  // add an op to `builder`; as sibling call arguments their relative
  // evaluation order would be unspecified, so the order in which they are
  // added could differ between compilers.
  auto update = ConstantR1<float>(&builder, {1.0});
  auto start_index = ConstantR0<int32_t>(&builder, 1);
  DynamicUpdateSlice(x, update, absl::Span<const XlaOp>({start_index}));
  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis starting at the dynamic-update-slice, which is the
  // entry root.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // One float read from the update + one float written (8 bytes), plus the
  // 4-byte start index.
  EXPECT_EQ(analysis.bytes_accessed(), 8 + 4);

  // The array being updated (operand 0) is expected to contribute no bytes;
  // only the update, the index and the written output element are charged.
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), 0);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(float));
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(int32_t));
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float));
}
1139
// Checks bytes accessed for a gather that selects two rows (slice size 1x3)
// from an s32[3,3] operand using an s32[2] index vector.
TEST_F(HloCostAnalysisTest, Gather) {
  XlaBuilder builder("gather");
  Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});

  auto operand = Parameter(&builder, 0, operand_shape, "operand");
  auto indices = Parameter(&builder, 1, indices_shape, "indices");
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_collapsed_slice_dims(0);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  Gather(operand, indices, dim_numbers, {1, 3});

  auto hlo_module = BuildHloGraph(&builder);

  // Run HLO cost analysis starting at the gather, which is the entry root.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(root->Accept(&analysis));

  // 24 bytes read from the operand + 8 bytes of indices + 24 bytes written.
  EXPECT_EQ(analysis.bytes_accessed(), 56);

  // The operand and the output hold S32 elements, so size the expectations
  // with int32_t. Only the 2x3 gathered elements of the 3x3 operand are
  // expected to be charged.
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(int32_t) * 2 * 3);
  EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32_t) * 2);
  EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(int32_t) * 2 * 3);
}
1169
// Checks bytes accessed for a scatter that combines an f32[2,3] update into
// an f32[3,3] operand at two s32 row indices, using the fixture's `add_`
// computation.
TEST_F(HloCostAnalysisTest, Scatter) {
  XlaBuilder builder("scatter");
  const Shape target_shape = ShapeUtil::MakeShape(F32, {3, 3});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {2});
  const Shape update_shape = ShapeUtil::MakeShape(F32, {2, 3});

  auto target = Parameter(&builder, 0, target_shape, "operand");
  auto row_indices = Parameter(&builder, 1, index_shape, "indices");
  auto updates = Parameter(&builder, 2, update_shape, "values");
  ScatterDimensionNumbers scatter_dims;
  scatter_dims.set_index_vector_dim(1);
  scatter_dims.add_update_window_dims(1);
  scatter_dims.add_inserted_window_dims(0);
  scatter_dims.add_scatter_dims_to_operand_dims(0);
  Scatter(target, row_indices, updates, add_, scatter_dims);

  auto hlo_module = BuildHloGraph(&builder);

  // Run the cost analysis starting at the scatter, which is the entry root.
  HloInstruction* scatter_root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(scatter_root->Accept(&analysis));

  // Two 4-byte indices plus three 2x3 blocks of 4-byte elements (operand,
  // updates and output).
  EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 3 * (2 * 3)));

  // Size in bytes of one f32 block of 2x3 elements.
  constexpr auto kBlockBytes = sizeof(float) * 2 * 3;
  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 0), kBlockBytes);
  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 1),
            sizeof(int32_t) * 2);
  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 2), kBlockBytes);
  EXPECT_EQ(analysis.output_bytes_accessed(*scatter_root), kBlockBytes);
}
1202
// Checks bytes accessed for a variadic scatter that updates an f32[3,3] and an
// s32[3,3] operand together, using a shared s32[2] index vector.
TEST_F(HloCostAnalysisTest, MultioutputScatter) {
  XlaBuilder builder("scatter");
  const Shape float_target_shape = ShapeUtil::MakeShape(F32, {3, 3});
  const Shape int_target_shape = ShapeUtil::MakeShape(S32, {3, 3});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {2});
  const Shape float_update_shape = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape int_update_shape = ShapeUtil::MakeShape(S32, {2, 3});

  auto float_target = Parameter(&builder, 0, float_target_shape, "operand0");
  auto int_target = Parameter(&builder, 1, int_target_shape, "operand1");
  auto row_indices = Parameter(&builder, 2, index_shape, "indices");
  auto float_updates = Parameter(&builder, 3, float_update_shape, "values0");
  auto int_updates = Parameter(&builder, 4, int_update_shape, "values1");
  ScatterDimensionNumbers scatter_dims;
  scatter_dims.set_index_vector_dim(1);
  scatter_dims.add_update_window_dims(1);
  scatter_dims.add_inserted_window_dims(0);
  scatter_dims.add_scatter_dims_to_operand_dims(0);

  // Combiner computation: takes (f32, s32, f32, s32) scalars and returns the
  // tuple (x0 + y0, x1 + y1). It is built with its own builder.
  auto pairwise_add = [] {
    XlaBuilder combiner_builder("add");
    auto x0 =
        Parameter(&combiner_builder, 0, ShapeUtil::MakeShape(F32, {}), "x0");
    auto x1 =
        Parameter(&combiner_builder, 1, ShapeUtil::MakeShape(S32, {}), "x1");
    auto y0 =
        Parameter(&combiner_builder, 2, ShapeUtil::MakeShape(F32, {}), "y0");
    auto y1 =
        Parameter(&combiner_builder, 3, ShapeUtil::MakeShape(S32, {}), "y1");
    Tuple(&combiner_builder, {Add(x0, y0), Add(x1, y1)});
    auto built = combiner_builder.Build();
    TF_CHECK_OK(built.status());
    return std::move(built).ValueOrDie();
  }();
  Scatter({float_target, int_target}, row_indices,
          {float_updates, int_updates}, pairwise_add, scatter_dims);

  auto hlo_module = BuildHloGraph(&builder);

  // Run the cost analysis starting at the scatter, which is the entry root.
  HloInstruction* scatter_root =
      hlo_module->entry_computation()->root_instruction();
  HloCostAnalysis analysis(ShapeSize);
  ASSERT_IS_OK(scatter_root->Accept(&analysis));

  // Two 4-byte indices plus, for each of the two scattered arrays, three 2x3
  // blocks of 4-byte elements (operand, updates and output).
  EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * 3 * (2 * 3)));

  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 0),
            sizeof(float) * 2 * 3);
  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 1),
            sizeof(int32_t) * 2 * 3);
  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 2),
            sizeof(int32_t) * 2);
  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 3),
            sizeof(float) * 2 * 3);
  EXPECT_EQ(analysis.operand_bytes_accessed(*scatter_root, 4),
            sizeof(int32_t) * 2 * 3);
  EXPECT_EQ(analysis.output_bytes_accessed(*scatter_root),
            2 * sizeof(float) * 2 * 3);
}
1252
1253 } // namespace
1254 } // namespace xla
1255