/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {

constexpr const char HloCostAnalysis::kFlopsKey[];
constexpr const char HloCostAnalysis::kTranscendentalsKey[];
constexpr const char HloCostAnalysis::kBytesAccessedKey[];
constexpr const char HloCostAnalysis::kOptimalSecondsKey[];

HloCostAnalysis::HloCostAnalysis(const Options& options) : options_(options) {}

HloCostAnalysis::HloCostAnalysis(ShapeSizeFunction shape_size,
                                 const Properties& per_second_rates)
    : HloCostAnalysis(Options{shape_size, per_second_rates}) {}

Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
  // Set current instruction cost values to reasonable default values. Each
  // handler can overwrite these values. In Postprocess, these values are
  // accumulated and written to the per-instruction maps.
  current_properties_.clear();
  current_should_compute_bottleneck_time_ = true;

  // The default number of bytes accessed for an instruction is the sum of the
  // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does
  // not handle opaque types.
  float bytes_accessed = GetShapeSize(hlo->shape());
  SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
  for (int64_t i = 0; i < hlo->operand_count(); ++i) {
    const HloInstruction* operand = hlo->operand(i);
    bytes_accessed += GetShapeSize(operand->shape());
    SetOperandBytesAccessed(i, GetShapeSize(operand->shape()));
  }
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  return OkStatus();
}

Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
  if (current_should_compute_bottleneck_time_) {
    // Compute the time as the time of the bottleneck, i.e. the slowest
    // property given the per-second rate of each property.
    float optimal_seconds = 0.0f;
    for (const auto& property : current_properties_) {
      if (property.first != kOptimalSecondsKey) {
        optimal_seconds = std::max(
            optimal_seconds,
            property.second / GetProperty(property.first,
                                          options_.per_second_rates, INFINITY));
      }
    }
    current_properties_[kOptimalSecondsKey] = optimal_seconds;
  }

  TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
  for (const auto& property : current_properties_) {
    properties_sum_[property.first] += property.second;
  }

  return OkStatus();
}

Status HloCostAnalysis::HandleElementwiseOp(
    const HloInstruction* hlo_instruction) {
  const auto& shape = hlo_instruction->shape();
  // For element-wise operations, the number of computations is the same as the
  // number of elements in the output shape.
  auto computation_count = ShapeUtil::ElementsIn(shape);
  auto opcode = hlo_instruction->opcode();
  // We treat transcendental operations separately since one transcendental
  // operation can correspond to several floating point ops.
  // kLogistic is included in "transcendental" as it is implemented using
  // transcendental ops (tanh or exp).
  if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
      opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
      opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt ||
      opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
      opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
      opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
      opcode == HloOpcode::kAtan2) {
    current_properties_[kTranscendentalsKey] = computation_count;
  } else {
    // Note: transcendental operations are considered a separate category from
    // FLOPs.
    current_properties_[kFlopsKey] = computation_count;
  }
  return OkStatus();
}

/*static*/ float HloCostAnalysis::GetProperty(absl::string_view key,
                                              const Properties& properties,
                                              const float default_value) {
  auto key_value = properties.find(key);
  return key_value == properties.end() ? default_value : key_value->second;
}

/*static*/ float HloCostAnalysis::GetPropertyForHlo(
    const HloInstruction& hlo, const std::string& key,
    const HloToProperties& hlo_to_properties) {
  auto it = hlo_to_properties.find(&hlo);
  if (it == hlo_to_properties.end()) {
    return 0.0f;
  } else {
    return GetProperty(key, it->second);
  }
}

int64_t HloCostAnalysis::GetShapeSize(const Shape& shape) const {
  if (!LayoutUtil::HasLayout(shape)) {
    return 0;
  }
  return options_.shape_size(shape);
}

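// Estimates the bytes read for a fused parameter (or get-tuple-element) by
// looking at how its users inside the fusion consume it: users that only read
// a slice are charged the slice size, while other users share a single full
// read of the parameter's shape.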
int64_t HloCostAnalysis::FusionParameterReadBytes(
    const HloInstruction* hlo) const {
  int64_t size = 0;
  bool seen_trivial_user = false;
  CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
                           hlo->opcode() == HloOpcode::kGetTupleElement));
  for (const HloInstruction* user : hlo->users()) {
    switch (user->opcode()) {
      case HloOpcode::kFusion: {
        for (int64_t idx : user->OperandIndices(hlo)) {
          size += FusionParameterReadBytes(user->fused_parameter(idx));
        }
        break;
      }
      case HloOpcode::kSlice:
        size += GetShapeSize(user->shape());
        break;
      case HloOpcode::kDynamicSlice:
        if (hlo == user->operand(0)) {
          size += GetShapeSize(user->shape());
        } else if (!seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
        break;
      case HloOpcode::kDynamicUpdateSlice:
        // Operand 0 is aliased to the output.
        if (hlo != user->operand(0) && !seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
        break;
      case HloOpcode::kBroadcast:
      case HloOpcode::kReshape:
        size += GetShapeSize(hlo->shape());
        break;
      default:
        // Other instructions reading this parameter are assumed to be able to
        // share the read from memory.
        if (!seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
    }
  }
  return size;
}

Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
  return HandleElementwiseOp(compare);
}

Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
  return HandleElementwiseOp(clamp);
}

Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleIota(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleGetTupleElement(
    const HloInstruction* get_tuple_element) {
  // GetTupleElement forwards a pointer and does not touch each element in the
  // output.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
  current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
  SetOutputBytesAccessed(GetShapeSize(slice->shape()));
  SetOperandBytesAccessed(0, GetShapeSize(slice->shape()));
  return OkStatus();
}

Status HloCostAnalysis::HandleDynamicSlice(
    const HloInstruction* dynamic_slice) {
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(dynamic_slice->shape()) * 2 +
      GetShapeSize(dynamic_slice->operand(1)->shape());
  SetOutputBytesAccessed(GetShapeSize(dynamic_slice->shape()));
  SetOperandBytesAccessed(0, GetShapeSize(dynamic_slice->shape()));
  SetOperandBytesAccessed(1, GetShapeSize(dynamic_slice->operand(1)->shape()));
  return OkStatus();
}

Status HloCostAnalysis::HandleDynamicUpdateSlice(
    const HloInstruction* dynamic_update_slice) {
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2 +
      GetShapeSize(dynamic_update_slice->operand(2)->shape());
  // Operand 0 aliases with the output.
  SetOutputBytesAccessed(
      GetShapeSize(dynamic_update_slice->operand(1)->shape()));
  SetOperandBytesAccessed(0, 0);
  SetOperandBytesAccessed(
      1, GetShapeSize(dynamic_update_slice->operand(1)->shape()));
  SetOperandBytesAccessed(
      2, GetShapeSize(dynamic_update_slice->operand(2)->shape()));
  return OkStatus();
}

Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
  // The tuple instruction only gathers pointers from inputs (it doesn't
  // iterate through them). The memory touched is then only the size of the
  // output index table of the tuple.

  current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
  SetOutputBytesAccessed(GetShapeSize(tuple->shape()));
  for (int i = 0; i < tuple->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
  return HandleElementwiseOp(convert);
}

Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
  // Domain does not have any computation or data transfer.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < domain->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

/* static */
int64_t HloCostAnalysis::GetDotFlops(const Shape& lhs_shape,
                                     const Shape& result_shape,
                                     const DotDimensionNumbers& dnums) {
  // Count of elements along the reduction dimensions.
  int64_t reduction_width = 1;
  for (auto dim : dnums.lhs_contracting_dimensions()) {
    reduction_width *= lhs_shape.dimensions(dim);
  }
  // Each output element requires reduction_width FMA operations.
  return kFmaFlops * ShapeUtil::ElementsIn(result_shape) * reduction_width;
}

Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
  current_properties_[kFlopsKey] = GetDotFlops(
      dot->operand(0)->shape(), dot->shape(), dot->dot_dimension_numbers());
  return OkStatus();
}

Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
  // Count nested infeed output tuples.
  int64_t size = 0;
  for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
    size += GetShapeSize(indexed_shape.shape);
    SetOutputBytesAccessed(indexed_shape.index,
                           GetShapeSize(indexed_shape.shape));
  }
  SetOutputBytesAccessed(size);
  current_properties_[kBytesAccessedKey] = size;
  return OkStatus();
}

Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
  // Count nested outfeed operand tuples.
  current_properties_[kBytesAccessedKey] = 0;
  for (int64_t i = 0; i < outfeed->operand_count(); ++i) {
    const HloInstruction* operand = outfeed->operand(i);
    int64_t size = 0;
    for (const auto& indexed_shape :
         ShapeUtil::GetLeafShapes(operand->shape())) {
      size += GetShapeSize(indexed_shape.shape);
      SetOperandBytesAccessed(i, indexed_shape.index,
                              GetShapeSize(indexed_shape.shape));
    }
    SetOperandBytesAccessed(i, size);
    current_properties_[kBytesAccessedKey] += size;
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
  // Compute properties of the mapped function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(map->to_apply()));

  // Compute the cost of all elements for this Map operation.
  const int64_t element_count = ShapeUtil::ElementsIn(map->shape());
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
  HloComputation* function = reduce->to_apply();
  // Compute the cost of the user function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this Reduce operation.
  // This counts the number of times the reduction function is applied, so it
  // does not need to be multiplied by the number of input tensors - that's
  // already "priced in" by the sub-computation doing more work.
  auto arg = reduce->operand(0);
  auto output_shape = reduce->shape().IsArray()
                          ? reduce->shape()
                          : reduce->shape().tuple_shapes(0);
  int64_t reduction_count =
      ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleReduceWindow(
    const HloInstruction* reduce_window) {
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  // Compute the properties of the reduction function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this ReduceWindow operation. For each
  // output element there are window_size - 1 reductions to perform.
  int64_t window_element_count = 1;
  for (const auto& dimension : window.dimensions()) {
    window_element_count *= dimension.size();
  }

  const int64_t output_element_count =
      ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
                                ? reduce_window->shape()
                                : reduce_window->shape().tuple_shapes(0));
  const int64_t reduction_count =
      (window_element_count - 1) * output_element_count;
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleSelectAndScatter(
    const HloInstruction* instruction) {
  // Compute the properties of the select and scatter functions.
  TF_ASSIGN_OR_RETURN(const Properties select_properties,
                      ProcessSubcomputation(instruction->select()));
  TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
                      ProcessSubcomputation(instruction->scatter()));

  // Compute the cost of all elements for this operation. For each scatter
  // source element there are window_size - 1 select computations to perform
  // and 1 scatter computation to perform.
  const auto source = instruction->operand(1);
  const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
  int64_t window_element_count = 1;
  for (const auto& dimension : instruction->window().dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64_t select_count =
      source_element_count * (window_element_count - 1);
  for (const auto& property : select_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] += property.second * select_count;
    }
  }
  for (const auto& property : scatter_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] +=
          property.second * source_element_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast does no computation and touches no memory.
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandlePad(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleAsyncStart(const HloInstruction* async_start) {
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(async_start->called_computations()[0]));
  return OkStatus();
}

Status HloCostAnalysis::HandleAsyncUpdate(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleAsyncDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleCopyStart(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleCopyDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleSend(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return OkStatus();
}

Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
  return OkStatus();
}

Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
  return OkStatus();
}

Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) {
  if (transpose->IsEffectiveBitcast()) {
    return HandleBitcast(transpose);
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleAfterAll(const HloInstruction* token) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < token->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleAddDependency(
    const HloInstruction* add_dependency) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < add_dependency->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

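// Convenience overload that pulls the operand and result shapes off the
// convolution instruction and forwards to the static overload below.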
int64_t HloCostAnalysis::GetConvolutionFlops(
    const HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();
  const Shape& result_shape = convolution->shape();

  return GetConvolutionFlops(convolution, lhs_shape, rhs_shape, result_shape);
}

/* static */
int64_t HloCostAnalysis::GetConvolutionFlops(const HloInstruction* convolution,
                                             const Shape& lhs_shape,
                                             const Shape& rhs_shape,
                                             const Shape& result_shape) {
  Window window = convolution->window();
  const auto& dnums = convolution->convolution_dimension_numbers();
  const int64_t input_batch_dim = dnums.input_batch_dimension();
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  const int64_t output_feature_dim = dnums.output_feature_dimension();
  const int64_t input_feature =
      ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
  const int64_t output_feature =
      ShapeUtil::GetDimension(result_shape, output_feature_dim);
  const int64_t batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);

  DimensionVector kernel_limits;
  DimensionVector output_limits;
  DimensionVector input_limits;
  if (window.dimensions().empty()) {
    window = window_util::MakeWindow({1});
    kernel_limits.push_back(1);
    output_limits.push_back(1);
    input_limits.push_back(1);
  } else {
    for (int64_t spatial_dimension = 0;
         spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
      // Spatial dimension number for kernel (rhs).
      const int64_t kernel_spatial_dim =
          dnums.kernel_spatial_dimensions(spatial_dimension);
      const int64_t kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
      kernel_limits.push_back(kernel_limit);

      // Spatial dimension number for output.
      const int64_t output_spatial_dim =
          dnums.output_spatial_dimensions(spatial_dimension);
      const int64_t output_limit = result_shape.dimensions(output_spatial_dim);
      output_limits.push_back(output_limit);

      // Spatial dimension number for input (lhs).
      const int64_t input_spatial_dim =
          dnums.input_spatial_dimensions(spatial_dimension);
      const int64_t input_limit = lhs_shape.dimensions(input_spatial_dim);
      input_limits.push_back(input_limit);
    }
  }

  DimensionVector valid_position_counts;

  // Loop over each spatial dimension.
  for (int64_t spatial_dimension = 0;
       spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
    const auto& window_dim = window.dimensions(spatial_dimension);
    // These two conditions will create an N^2 iteration pattern with only N
    // valid elements. This is a performance optimization and produces the same
    // result as the whole loop.
    if (input_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        input_limits[spatial_dimension] == window_dim.base_dilation() &&
        window_dim.window_dilation() == 1 &&
        std::max<int64_t>(1, input_limits[spatial_dimension] - 1) ==
            window_dim.stride() &&
        window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
      valid_position_counts.push_back(input_limits[spatial_dimension]);
      continue;
    }

    if (input_limits[spatial_dimension] == 1 &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        window_dim.window_dilation() == 1 && window_dim.base_dilation() == 1 &&
        window_dim.stride() == 1 &&
        window_dim.padding_high() == output_limits[spatial_dimension] - 1 &&
        window_dim.padding_low() == output_limits[spatial_dimension] - 1) {
      valid_position_counts.push_back(output_limits[spatial_dimension]);
      continue;
    }

    int64_t valid_position_count = 0;
    // Loop over each point in the kernel.
    for (int64_t kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
         ++kernel_idx) {
      // Loop over each point in the output.
      for (int64_t output_idx = 0;
           output_idx < output_limits[spatial_dimension]; ++output_idx) {
        // Calculate lhs (input) index without taking base dilation into
        // account.
        const int64_t undilated_index =
            output_idx * window_dim.stride() - window_dim.padding_low() +
            kernel_idx * window_dim.window_dilation();

        // Calculate the actual lhs (input) index after dilation. Avoid the
        // division as an optimization.
        const int64_t lhs_spatial_index =
            window_dim.base_dilation() > 1
                ? undilated_index / window_dim.base_dilation()
                : undilated_index;

        // Skip if the lhs (input) index is to be dilated.
        if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) {
          continue;
        }

        // Skip if input index is not in bound.
        if (lhs_spatial_index < 0 ||
            lhs_spatial_index >= input_limits[spatial_dimension]) {
          continue;
        }

        valid_position_count += 1;
      }
    }
    valid_position_counts.push_back(valid_position_count);
  }

  const int64_t fma_count =
      (input_feature / convolution->feature_group_count()) * output_feature *
      (batch / convolution->batch_group_count()) *
      Product(valid_position_counts);
  return fma_count * kFmaFlops;
}

Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  current_properties_[kFlopsKey] = GetConvolutionFlops(convolution);
  return OkStatus();
}

Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
  auto real_shape =
      fft->operand(0)->shape().IsTuple()
          ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
          : fft->operand(0)->shape();
  constexpr int kFmaPerComplexMul = 4;
  int64_t log_factors = 1;
  for (int64_t dim : fft->fft_length()) {
    log_factors *= Log2Floor<uint64_t>(dim);
  }
  current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
                                   ShapeUtil::ElementsIn(real_shape);
  return OkStatus();
}

Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
  // Half of operand 0 is read.
  float bytes_accessed = GetShapeSize(hlo->shape());
  SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
  bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
  SetOperandBytesAccessed(1, GetShapeSize(hlo->operand(1)->shape()));
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  const Shape& a_shape = hlo->operand(0)->shape();
  const Shape& b_shape = hlo->operand(1)->shape();
  // Estimate as batch * mn^2 / 2 flops.
  int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
  elems *= ShapeUtil::ElementsIn(b_shape);
  current_properties_[kFlopsKey] = kFmaFlops * elems;
  return OkStatus();
}

Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
  // Half of operand 0 is read and half of the output will be written.
  float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOutputBytesAccessed(GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  const Shape& a_shape = hlo->operand(0)->shape();
  // Estimate as batch * n^3 / 3 flops.
  int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
  elems *= ShapeUtil::ElementsIn(a_shape);
  current_properties_[kFlopsKey] = elems / 3;
  return OkStatus();
}

Status HloCostAnalysis::HandleOptimizationBarrier(
    const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleAllGather(const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleAllGatherStart(const HloInstruction* hlo) {
  return HandleAllGather(hlo);
}

Status HloCostAnalysis::HandleAllGatherDone(const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
  // We assume 2 replicas, so that each output element is the sum of two input
  // elements.
  //
  // TODO(b/33004697): Compute correct cost here, taking the actual number of
  // replicas into account.
  double flops = 0.0;
  int64_t output_bytes_accessed = 0;
  ShapeUtil::ForEachSubshape(
      crs->shape(), [&](const Shape& subshape, const ShapeIndex&) {
        if (subshape.IsArray()) {
          flops += ShapeUtil::ElementsIn(subshape);
          output_bytes_accessed += GetShapeSize(subshape);
        }
      });
  int64_t bytes_accessed = output_bytes_accessed;
  for (const HloInstruction* operand : crs->operands()) {
    bytes_accessed += GetShapeSize(operand->shape());
  }
  current_properties_[kFlopsKey] = flops;
  SetOutputBytesAccessed(output_bytes_accessed);
  current_properties_[kBytesAccessedKey] = bytes_accessed;
  return OkStatus();
}

Status HloCostAnalysis::HandleReduceScatter(const HloInstruction* hlo) {
  return OkStatus();
}

Status HloCostAnalysis::HandleAllReduceStart(const HloInstruction* hlo) {
  return HandleAllReduce(hlo);
}

Status HloCostAnalysis::HandleAllReduceDone(const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
  return OkStatus();
}

Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleCollectivePermuteStart(
    const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleCollectivePermuteDone(
    const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now, assume
  // the cost of each RNG is the same as a transcendental operation.
  current_properties_[kTranscendentalsKey] =
      ShapeUtil::ElementsIn(random->shape());
  return OkStatus();
}

Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now, assume
  // the cost of each RNG is the same as a transcendental operation.
  current_properties_[kTranscendentalsKey] =
      ShapeUtil::ElementsInRecursive(random->shape());
  return OkStatus();
}

Status HloCostAnalysis::HandleRngGetAndUpdateState(
    const HloInstruction* random) {
  return OkStatus();
}

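// For fusions, cost is computed from the fused computation; bytes accessed are
// then recomputed to count only reads of fusion parameters and writes of the
// fusion root, since intermediates are assumed to stay in fast memory. Custom
// fusions that wrap a gather or scatter are costed as that op directly.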
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
  if (fusion->IsCustomFusion()) {
    for (const HloInstruction* hlo :
         fusion->fused_instructions_computation()->instructions()) {
      if (hlo->opcode() == HloOpcode::kGather) {
        return HandleGather(hlo);
      }
      if (hlo->opcode() == HloOpcode::kScatter) {
        return HandleScatter(hlo);
      }
    }
  }
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(fusion->fused_instructions_computation()));

  // Fusion nodes that produce a tuple also produce the entries in the tuple.
  // Ignore the memory accessed inside fused ops, since fusion is supposed to
  // prevent intermediate data from touching slow memory.
  current_properties_[kBytesAccessedKey] = 0;
  ShapeUtil::ForEachSubshape(
      fusion->shape(),
      [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
        if (!subshape.IsArray()) {
          return;
        }
        if (shape_index.empty()) {
          if (fusion->fused_expression_root()->opcode() ==
              HloOpcode::kDynamicUpdateSlice) {
            int64_t size = GetShapeSize(
                fusion->fused_expression_root()->operand(1)->shape());
            current_properties_[kBytesAccessedKey] += size;
            SetOutputBytesAccessed(shape_index, size);
            return;
          }
        } else if (shape_index.size() == 1) {
          if (fusion->fused_expression_root()->opcode() == HloOpcode::kTuple &&
              fusion->fused_expression_root()
                      ->operand(shape_index[0])
                      ->opcode() == HloOpcode::kDynamicUpdateSlice) {
            int64_t size = GetShapeSize(fusion->fused_expression_root()
                                            ->operand(shape_index[0])
                                            ->operand(1)
                                            ->shape());
            current_properties_[kBytesAccessedKey] += size;
            SetOutputBytesAccessed(shape_index, size);
            return;
          }
        }
        current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
        SetOutputBytesAccessed(shape_index, GetShapeSize(subshape));
      });

  if (fusion->shape().IsTuple()) {
    // Propagate and accumulate the output tuple bytes from the tuple
    // subshapes. This ensures we have the correct output bytes accessed for
    // the shape index {}.
    std::function<float(const Shape&, const ShapeIndex&)>
        propagate_output_size_to_parent;
    propagate_output_size_to_parent = [&](const Shape& shape,
                                          const ShapeIndex& shape_index) {
      auto output_bytes_it =
          current_properties_.find(GetOutputBytesAccessedKey(shape_index));
      if (output_bytes_it != current_properties_.end()) {
        return output_bytes_it->second;
      }
      float bytes_accessed = 0;
      for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
        const Shape& subshape = shape.tuple_shapes(i);
        ShapeIndex subshape_index(shape_index);
        subshape_index.push_back(i);
        bytes_accessed +=
            propagate_output_size_to_parent(subshape, subshape_index);
      }
      SetOutputBytesAccessed(shape_index, bytes_accessed);
      return bytes_accessed;
    };
    auto output_bytes_it =
        current_properties_.find(GetOutputBytesAccessedKey());
    if (output_bytes_it != current_properties_.end()) {
      current_properties_.erase(output_bytes_it);
    }
    propagate_output_size_to_parent(fusion->shape(), {});
  }

  for (int64_t i = 0; i < fusion->fused_parameters().size(); ++i) {
    const HloInstruction* operand = fusion->fused_parameter(i);
    int64_t operand_size = 0;
    if (!operand->shape().IsTuple()) {
      operand_size = FusionParameterReadBytes(operand);
    } else {
      // If the fusion parameter is a tuple type, find the gte for the leaf
      // shape and calculate the bytes accessed for those array types.
      for (const auto& indexed_shape :
           ShapeUtil::GetLeafShapes(operand->shape())) {
        const HloInstruction* gte = operand;
        for (int64_t index : indexed_shape.index) {
          for (const HloInstruction* user : gte->users()) {
            if (user->opcode() == HloOpcode::kGetTupleElement &&
                user->tuple_index() == index) {
              gte = user;
              break;
            }
          }
        }
        int64_t size = FusionParameterReadBytes(gte);
        operand_size += size;
        SetOperandBytesAccessed(i, indexed_shape.index, size);
      }
    }
    current_properties_[kBytesAccessedKey] += operand_size;
    SetOperandBytesAccessed(i, operand_size);
  }

  return OkStatus();
}

Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
  TF_ASSIGN_OR_RETURN(current_properties_,
                      ProcessSubcomputation(call->to_apply()));
  current_should_compute_bottleneck_time_ = false;
  return OkStatus();
}

Status HloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
  // Mark applicable fields as "unknown", since we don't know what this
  // CustomCall does. This is better than returning an error, which would stop
  // iteration, and therefore would prevent us from getting *any* stats for a
  // computation which contains a CustomCall.
  current_properties_[kOptimalSecondsKey] = -1;
  current_properties_[kBytesAccessedKey] = -1;
  SetOutputBytesAccessed(-1);
  for (int i = 0; i < custom_call->operand_count(); ++i) {
    SetOperandBytesAccessed(i, -1);
  }
  current_properties_[kFlopsKey] = -1;
  current_should_compute_bottleneck_time_ = false;
  return OkStatus();
}

Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
  // This assumes a comparison based N*log(N) algorithm. As for all ops, the
  // actual properties of the op depend on the backend implementation.
  int64_t elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
  current_properties_[kFlopsKey] = elements * Log2Ceiling<uint64_t>(elements);
  return OkStatus();
}

Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
  // Since the number of iterations of the while node will not always be
  // something that we can statically analyze, we cannot precisely compute the
  // cost of a while node. For now compute the cost of a single iteration.
  TF_ASSIGN_OR_RETURN(const Properties body_properties,
                      ProcessSubcomputation(xla_while->while_body()));

  TF_ASSIGN_OR_RETURN(const Properties condition_properties,
                      ProcessSubcomputation(xla_while->while_condition()));

  current_properties_.clear();
  for (const auto& property : body_properties) {
    current_properties_[property.first] += property.second;
  }
  for (const auto& property : condition_properties) {
    current_properties_[property.first] += property.second;
  }
  current_should_compute_bottleneck_time_ = false;

  return OkStatus();
}

Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
  // Compute the cost of the branch computations and take the maximum from
  // those for each property.
  TF_ASSIGN_OR_RETURN(
      const Properties branch0_computation_properties,
      ProcessSubcomputation(conditional->branch_computation(0)));
  current_properties_ = branch0_computation_properties;
  for (int j = 1; j < conditional->branch_count(); ++j) {
    TF_ASSIGN_OR_RETURN(
        const Properties branch_computation_properties,
        ProcessSubcomputation(conditional->branch_computation(j)));
    for (const auto& property : branch_computation_properties) {
      if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_,
                                               property)) {
        auto& current_property = current_properties_[property.first];
        current_property = std::max(current_property, property.second);
      }
    }
  }
  current_should_compute_bottleneck_time_ = false;

  return OkStatus();
}

Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
  // Gather doesn't read the whole input buffer, it's equivalent to a copy the
  // size of the output shape and a read of the gather indices.
  int64_t output_size = GetShapeSize(gather->shape());
  current_properties_[kBytesAccessedKey] =
      output_size * 2 + GetShapeSize(gather->operand(1)->shape());
  SetOperandBytesAccessed(0, output_size);
  SetOperandBytesAccessed(1, GetShapeSize(gather->operand(1)->shape()));
  SetOutputBytesAccessed(output_size);
  // Gather does not issue any flops.
  return OkStatus();
}

Status HloCostAnalysis::HandleScatter(const HloInstruction* hlo) {
  auto* scatter = Cast<HloScatterInstruction>(hlo);
  // Scatter accesses the equivalent of 3N update shapes (input, output, and
  // updates), and the scatter indices.
  int64_t total_update_size = 0;
  for (int i = 0, n = scatter->scatter_operand_count(); i < n; ++i) {
    int64_t update_size = GetShapeSize(scatter->scatter_updates()[i]->shape());
    SetOperandBytesAccessed(i, update_size);
    SetOperandBytesAccessed(n + 1 + i, update_size);
    total_update_size += update_size;
  }
  int64_t scatter_indices_size =
      GetShapeSize(scatter->scatter_indices()->shape());
  SetOperandBytesAccessed(scatter->scatter_operand_count(),
                          scatter_indices_size);
  current_properties_[kBytesAccessedKey] =
      total_update_size * 3 + scatter_indices_size;
  SetOutputBytesAccessed(total_update_size);
  const int64_t element_count =
      ShapeUtil::ElementsIn(scatter->scatter_updates()[0]->shape());
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(scatter->to_apply()));
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleGetDimensionSize(
    const HloInstruction* /*get_size*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleSetDimensionSize(
    const HloInstruction* /*set_size*/) {
  return OkStatus();
}

Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
  return OkStatus();
}

float HloCostAnalysis::flop_count() const {
  return GetProperty(kFlopsKey, properties_sum_);
}

float HloCostAnalysis::transcendental_count() const {
  return GetProperty(kTranscendentalsKey, properties_sum_);
}

float HloCostAnalysis::bytes_accessed() const {
  return GetProperty(kBytesAccessedKey, properties_sum_);
}

float HloCostAnalysis::optimal_seconds() const {
  return GetProperty(kOptimalSecondsKey, properties_sum_);
}

int64_t HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
}

int64_t HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
}

int64_t HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}

int64_t HloCostAnalysis::operand_bytes_accessed(const HloInstruction& hlo,
                                                int64_t operand_num,
                                                ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOperandBytesAccessedKey(operand_num, index),
                           hlo_properties_);
}

int64_t HloCostAnalysis::output_bytes_accessed(const HloInstruction& hlo,
                                               ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOutputBytesAccessedKey(index),
                           hlo_properties_);
}

float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}

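// Sums operand_bytes_accessed over all leaf shapes of every operand,
// optionally restricted to operands whose layout places them in
// `memory_space`.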
int64_t HloCostAnalysis::GetBytesRead(
    const HloInstruction& hlo, std::optional<int64_t> memory_space) const {
  int64_t bytes_read = 0;
  for (int operand_number = 0; operand_number < hlo.operand_count();
       ++operand_number) {
    const Shape& shape = hlo.operand(operand_number)->shape();
    ShapeUtil::ForEachSubshape(
        shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
          if (ShapeUtil::IsLeafIndex(shape, index)) {
            std::optional<int64_t> index_memory_space;
            if (sub_shape.has_layout()) {
              index_memory_space = sub_shape.layout().memory_space();
            }
            if (!memory_space || memory_space == index_memory_space) {
              bytes_read += operand_bytes_accessed(hlo, operand_number, index);
            }
          }
        });
  }
  return bytes_read;
}

int64_t HloCostAnalysis::GetBytesWritten(
    const HloInstruction& hlo, std::optional<int64_t> memory_space) const {
  int64_t bytes_written = 0;
  for (const ShapeUtil::IndexedShape& indexed_shape :
       ShapeUtil::GetLeafShapes(hlo.shape())) {
    std::optional<int64_t> index_memory_space;
    if (indexed_shape.shape.has_layout()) {
      index_memory_space = indexed_shape.shape.layout().memory_space();
    }
    if (!memory_space || memory_space == index_memory_space) {
      bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
    }
  }
  return bytes_written;
}

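// Runs a nested cost analysis over `computation`, merges its per-instruction
// results into hlo_properties_, and returns the nested analysis' accumulated
// property totals.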
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
    HloComputation* computation) {
  auto visitor = CreateNestedCostAnalysis();
  visitor->ReserveVisitStates(computation->instruction_count());
  TF_RETURN_IF_ERROR(computation->Accept(visitor.get()));
  hlo_properties_.insert(visitor->hlo_properties_.begin(),
                         visitor->hlo_properties_.end());
  return visitor->properties();
}

std::unique_ptr<HloCostAnalysis> HloCostAnalysis::CreateNestedCostAnalysis() {
  return std::make_unique<HloCostAnalysis>(options_);
}

void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
                                              float value) {
  current_properties_[GetOperandBytesAccessedKey(operand_num).c_str()] = value;
}

void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
                                              ShapeIndex index, float value) {
  current_properties_[GetOperandBytesAccessedKey(operand_num, index).c_str()] =
      value;
}

void HloCostAnalysis::SetOutputBytesAccessed(float value) {
  current_properties_[GetOutputBytesAccessedKey()] = value;
}

void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
  current_properties_[GetOutputBytesAccessedKey(index)] = value;
}

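// Builds the property-map key used to record bytes accessed for a particular
// operand (and shape index), of the form
// "<kBytesAccessedKey> operand <n> <index>".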
/*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey(
    int64_t operand_num, ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ",
                      index.ToString());
}

/*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey(
    ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " output ", index.ToString());
}

}  // namespace xla