/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {

constexpr const char HloCostAnalysis::kFlopsKey[];
constexpr const char HloCostAnalysis::kTranscendentalsKey[];
constexpr const char HloCostAnalysis::kBytesAccessedKey[];
constexpr const char HloCostAnalysis::kOptimalSecondsKey[];

HloCostAnalysis::HloCostAnalysis(const Options& options) : options_(options) {}
HloCostAnalysis::HloCostAnalysis(ShapeSizeFunction shape_size,
                                 const Properties& per_second_rates)
    : HloCostAnalysis(Options{shape_size, per_second_rates}) {}

Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
  // Set current instruction cost values to reasonable default values. Each
  // handler can overwrite these values. In Postprocess, these values are
  // accumulated and written to the per-instruction maps.
  current_properties_.clear();
  current_should_compute_bottleneck_time_ = true;

  // The default number of bytes accessed for an instruction is the sum of the
  // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
  // handle opaque types.
  float bytes_accessed = GetShapeSize(hlo->shape());
  SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
  for (int64_t i = 0; i < hlo->operand_count(); ++i) {
    const HloInstruction* operand = hlo->operand(i);
    bytes_accessed += GetShapeSize(operand->shape());
    SetOperandBytesAccessed(i, GetShapeSize(operand->shape()));
  }
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  return OkStatus();
}

Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
  if (current_should_compute_bottleneck_time_) {
    // Compute the time as the time of the bottleneck, i.e. the slowest property
    // given the per-second rate of each property.
    float optimal_seconds = 0.0f;
    for (const auto& property : current_properties_) {
      if (property.first != kOptimalSecondsKey) {
        optimal_seconds = std::max(
            optimal_seconds,
            property.second / GetProperty(property.first,
                                          options_.per_second_rates, INFINITY));
      }
    }
    current_properties_[kOptimalSecondsKey] = optimal_seconds;
  }

  TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
  for (const auto& property : current_properties_) {
    properties_sum_[property.first] += property.second;
  }

  return OkStatus();
}

Status HloCostAnalysis::HandleElementwiseOp(
    const HloInstruction* hlo_instruction) {
  const auto& shape = hlo_instruction->shape();
  // For element-wise operations, the number of computations is the same as the
  // number of elements in the output shape.
  auto computation_count = ShapeUtil::ElementsIn(shape);
  auto opcode = hlo_instruction->opcode();
  // We treat transcendental operations separately since one transcendental
  // operation can correspond to several floating point ops.
  // kLogistic is included in "transcendental" as it is implemented using
  // transcendental ops (tanh or exp).
  if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
      opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
      opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt ||
      opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
      opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
      opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
      opcode == HloOpcode::kAtan2) {
    current_properties_[kTranscendentalsKey] = computation_count;
  } else {
    // Note: transcendental operations are considered a separate category from
    // FLOPs.
    current_properties_[kFlopsKey] = computation_count;
  }
  return OkStatus();
}

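// Returns the value stored under `key` in `properties`, or `default_value` if
// the key is absent.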
/*static*/ float HloCostAnalysis::GetProperty(absl::string_view key,
                                              const Properties& properties,
                                              const float default_value) {
  auto key_value = properties.find(key);
  return key_value == properties.end() ? default_value : key_value->second;
}

/*static*/ float HloCostAnalysis::GetPropertyForHlo(
    const HloInstruction& hlo, const std::string& key,
    const HloToProperties& hlo_to_properties) {
  auto it = hlo_to_properties.find(&hlo);
  if (it == hlo_to_properties.end()) {
    return 0.0f;
  } else {
    return GetProperty(key, it->second);
  }
}

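// Shapes without a layout are treated as size 0, since their buffer size is
// not known; otherwise the size comes from the shape_size function supplied
// in the options.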
int64_t HloCostAnalysis::GetShapeSize(const Shape& shape) const {
  if (!LayoutUtil::HasLayout(shape)) {
    return 0;
  }
  return options_.shape_size(shape);
}

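// Estimates how many bytes of a fused parameter (or a get-tuple-element of
// one) are read, based on how its users consume it: slice-like users are
// charged only the bytes they actually read, broadcast/reshape users are
// charged the full parameter size, and all remaining users are assumed to
// share a single read of the whole parameter.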
int64_t HloCostAnalysis::FusionParameterReadBytes(
    const HloInstruction* hlo) const {
  int64_t size = 0;
  bool seen_trivial_user = false;
  CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
                           hlo->opcode() == HloOpcode::kGetTupleElement));
  for (const HloInstruction* user : hlo->users()) {
    switch (user->opcode()) {
      case HloOpcode::kFusion: {
        for (int64_t idx : user->OperandIndices(hlo)) {
          size += FusionParameterReadBytes(user->fused_parameter(idx));
        }
        break;
      }
      case HloOpcode::kSlice:
        size += GetShapeSize(user->shape());
        break;
      case HloOpcode::kDynamicSlice:
        if (hlo == user->operand(0)) {
          size += GetShapeSize(user->shape());
        } else if (!seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
        break;
      case HloOpcode::kDynamicUpdateSlice:
        // Operand 0 is aliased to the output.
        if (hlo != user->operand(0) && !seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
        break;
      case HloOpcode::kBroadcast:
      case HloOpcode::kReshape:
        size += GetShapeSize(hlo->shape());
        break;
      default:
        // Other instructions reading this parameter are assumed to be able to
        // share the read from memory.
        if (!seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
    }
  }
  return size;
}

Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
  return HandleElementwiseOp(compare);
}

Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
  return HandleElementwiseOp(clamp);
}

Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleIota(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleGetTupleElement(
    const HloInstruction* get_tuple_element) {
  // GetTupleElement forwards a pointer and does not touch each element in the
  // output.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
  return OkStatus();
}

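// A slice reads exactly the elements it produces, so the bytes read from the
// operand equal the bytes written to the output.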
Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
  current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
  SetOutputBytesAccessed(GetShapeSize(slice->shape()));
  SetOperandBytesAccessed(0, GetShapeSize(slice->shape()));
  return OkStatus();
}

Status HloCostAnalysis::HandleDynamicSlice(
    const HloInstruction* dynamic_slice) {
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(dynamic_slice->shape()) * 2 +
      GetShapeSize(dynamic_slice->operand(1)->shape());
  SetOutputBytesAccessed(GetShapeSize(dynamic_slice->shape()));
  SetOperandBytesAccessed(0, GetShapeSize(dynamic_slice->shape()));
  SetOperandBytesAccessed(1, GetShapeSize(dynamic_slice->operand(1)->shape()));
  return OkStatus();
}

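// For dynamic-update-slice only the update region is read and written:
// operand 0 aliases the output, operand 1 is the update, and operand 2 is the
// start indices.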
Status HloCostAnalysis::HandleDynamicUpdateSlice(
    const HloInstruction* dynamic_update_slice) {
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2 +
      GetShapeSize(dynamic_update_slice->operand(2)->shape());
  // Operand 0 aliases with the output.
  SetOutputBytesAccessed(
      GetShapeSize(dynamic_update_slice->operand(1)->shape()));
  SetOperandBytesAccessed(0, 0);
  SetOperandBytesAccessed(
      1, GetShapeSize(dynamic_update_slice->operand(1)->shape()));
  SetOperandBytesAccessed(
      2, GetShapeSize(dynamic_update_slice->operand(2)->shape()));
  return OkStatus();
}

Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
  // The tuple instruction only gathers pointers from inputs (it doesn't iterate
  // through them). The memory touched is then only the size of the output
  // index table of the tuple.

  current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
  SetOutputBytesAccessed(GetShapeSize(tuple->shape()));
  for (int i = 0; i < tuple->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
  return HandleElementwiseOp(convert);
}

Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
  // Domain does not have any computation or data transfer.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < domain->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

/* static */
int64_t HloCostAnalysis::GetDotFlops(const Shape& lhs_shape,
                                     const Shape& result_shape,
                                     const DotDimensionNumbers& dnums) {
  // Count of elements along the reduction dimensions.
  int64_t reduction_width = 1;
  for (auto dim : dnums.lhs_contracting_dimensions()) {
    reduction_width *= lhs_shape.dimensions(dim);
  }
  // Each output element requires reduction_width FMA operations.
  return kFmaFlops * ShapeUtil::ElementsIn(result_shape) * reduction_width;
}

Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
  current_properties_[kFlopsKey] = GetDotFlops(
      dot->operand(0)->shape(), dot->shape(), dot->dot_dimension_numbers());
  return OkStatus();
}

Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
  // Count nested infeed output tuples.
  int64_t size = 0;
  for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
    size += GetShapeSize(indexed_shape.shape);
    SetOutputBytesAccessed(indexed_shape.index,
                           GetShapeSize(indexed_shape.shape));
  }
  SetOutputBytesAccessed(size);
  current_properties_[kBytesAccessedKey] = size;
  return OkStatus();
}

Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
  // Count nested outfeed operand tuples.
  current_properties_[kBytesAccessedKey] = 0;
  for (int64_t i = 0; i < outfeed->operand_count(); ++i) {
    const HloInstruction* operand = outfeed->operand(i);
    int64_t size = 0;
    for (const auto& indexed_shape :
         ShapeUtil::GetLeafShapes(operand->shape())) {
      size += GetShapeSize(indexed_shape.shape);
      SetOperandBytesAccessed(i, indexed_shape.index,
                              GetShapeSize(indexed_shape.shape));
    }
    SetOperandBytesAccessed(i, size);
    current_properties_[kBytesAccessedKey] += size;
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
  // Compute properties of the mapped function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(map->to_apply()));

  // Compute the cost of all elements for this Map operation.
  const int64_t element_count = ShapeUtil::ElementsIn(map->shape());
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
  HloComputation* function = reduce->to_apply();
  // Compute the cost of the user function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this Reduce operation.
  // This counts the number of times the reduction function is applied, so it
  // does not need to be multiplied by the number of input tensors - that's
  // already "priced in" by the sub-computation doing more work.
  auto arg = reduce->operand(0);
  auto output_shape = reduce->shape().IsArray()
                          ? reduce->shape()
                          : reduce->shape().tuple_shapes(0);
  int64_t reduction_count =
      ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleReduceWindow(
    const HloInstruction* reduce_window) {
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  // Compute the properties of the reduction function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this ReduceWindow operation. For each
  // output element there are window_size - 1 reductions to perform.
  int64_t window_element_count = 1;
  for (const auto& dimension : window.dimensions()) {
    window_element_count *= dimension.size();
  }

  const int64_t output_element_count =
      ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
                                ? reduce_window->shape()
                                : reduce_window->shape().tuple_shapes(0));
  const int64_t reduction_count =
      (window_element_count - 1) * output_element_count;
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleSelectAndScatter(
    const HloInstruction* instruction) {
  // Compute the properties of the select and scatter computations.
  TF_ASSIGN_OR_RETURN(const Properties select_properties,
                      ProcessSubcomputation(instruction->select()));
  TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
                      ProcessSubcomputation(instruction->scatter()));

  // Compute the cost of all elements for this operation. For each scatter
  // source element there are window_size - 1 select computations to perform and
  // 1 scatter computation to perform.
  const auto source = instruction->operand(1);
  const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
  int64_t window_element_count = 1;
  for (const auto& dimension : instruction->window().dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64_t select_count =
      source_element_count * (window_element_count - 1);
  for (const auto& property : select_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] += property.second * select_count;
    }
  }
  for (const auto& property : scatter_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] +=
          property.second * source_element_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast does no computation and touches no memory.
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandlePad(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleAsyncStart(const HloInstruction* async_start) {
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(async_start->called_computations()[0]));
  return OkStatus();
}

Status HloCostAnalysis::HandleAsyncUpdate(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleAsyncDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleCopyStart(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleCopyDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleSend(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) { return OkStatus(); }

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
  return OkStatus();
}

Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return OkStatus();
}

Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
  return OkStatus();
}

Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
  return OkStatus();
}

Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) {
  if (transpose->IsEffectiveBitcast()) {
    return HandleBitcast(transpose);
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleAfterAll(const HloInstruction* token) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < token->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

Status HloCostAnalysis::HandleAddDependency(
    const HloInstruction* add_dependency) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < add_dependency->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return OkStatus();
}

int64_t HloCostAnalysis::GetConvolutionFlops(
    const HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();
  const Shape& result_shape = convolution->shape();

  return GetConvolutionFlops(convolution, lhs_shape, rhs_shape, result_shape);
}

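// Counts one FMA per output feature, per (input feature / feature group
// count), per (batch / batch group count), and per kernel position that lands
// on a valid (in-bounds, non-dilated) input element.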
/* static */
int64_t HloCostAnalysis::GetConvolutionFlops(const HloInstruction* convolution,
                                             const Shape& lhs_shape,
                                             const Shape& rhs_shape,
                                             const Shape& result_shape) {
  Window window = convolution->window();
  const auto& dnums = convolution->convolution_dimension_numbers();
  const int64_t input_batch_dim = dnums.input_batch_dimension();
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  const int64_t output_feature_dim = dnums.output_feature_dimension();
  const int64_t input_feature =
      ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
  const int64_t output_feature =
      ShapeUtil::GetDimension(result_shape, output_feature_dim);
  const int64_t batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);

  DimensionVector kernel_limits;
  DimensionVector output_limits;
  DimensionVector input_limits;
  if (window.dimensions().empty()) {
    window = window_util::MakeWindow({1});
    kernel_limits.push_back(1);
    output_limits.push_back(1);
    input_limits.push_back(1);
  } else {
    for (int64_t spatial_dimension = 0;
         spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
      // Spatial dimension number for kernel (rhs).
      const int64_t kernel_spatial_dim =
          dnums.kernel_spatial_dimensions(spatial_dimension);
      const int64_t kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
      kernel_limits.push_back(kernel_limit);

      // Spatial dimension number for output.
      const int64_t output_spatial_dim =
          dnums.output_spatial_dimensions(spatial_dimension);
      const int64_t output_limit = result_shape.dimensions(output_spatial_dim);
      output_limits.push_back(output_limit);

      // Spatial dimension number for input (lhs).
      const int64_t input_spatial_dim =
          dnums.input_spatial_dimensions(spatial_dimension);
      const int64_t input_limit = lhs_shape.dimensions(input_spatial_dim);
      input_limits.push_back(input_limit);
    }
  }

  DimensionVector valid_position_counts;

  // Loop over each spatial dimension.
  for (int64_t spatial_dimension = 0;
       spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
    const auto& window_dim = window.dimensions(spatial_dimension);
    // These two conditions will create an N^2 iteration pattern with only N
    // valid elements. This is a performance optimization and produces the same
    // result as the whole loop.
    if (input_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        input_limits[spatial_dimension] == window_dim.base_dilation() &&
        window_dim.window_dilation() == 1 &&
        std::max<int64_t>(1, input_limits[spatial_dimension] - 1) ==
            window_dim.stride() &&
        window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
      valid_position_counts.push_back(input_limits[spatial_dimension]);
      continue;
    }

    if (input_limits[spatial_dimension] == 1 &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        window_dim.window_dilation() == 1 && window_dim.base_dilation() == 1 &&
        window_dim.stride() == 1 &&
        window_dim.padding_high() == output_limits[spatial_dimension] - 1 &&
        window_dim.padding_low() == output_limits[spatial_dimension] - 1) {
      valid_position_counts.push_back(output_limits[spatial_dimension]);
      continue;
    }

    int64_t valid_position_count = 0;
    // Loop over each point in the kernel.
    for (int64_t kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
         ++kernel_idx) {
      // Loop over each point in the output.
      for (int64_t output_idx = 0;
           output_idx < output_limits[spatial_dimension]; ++output_idx) {
        // Calculate lhs (input) index without taking base dilation into
        // account.
        const int64_t undilated_index =
            output_idx * window_dim.stride() - window_dim.padding_low() +
            kernel_idx * window_dim.window_dilation();

        // Calculate the actual lhs (input) index after dilation. Avoid the
        // division as an optimization.
        const int64_t lhs_spatial_index =
            window_dim.base_dilation() > 1
                ? undilated_index / window_dim.base_dilation()
                : undilated_index;

        // Skip if the lhs (input) index is to be dilated.
        if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) {
          continue;
        }

        // Skip if the input index is not in bounds.
        if (lhs_spatial_index < 0 ||
            lhs_spatial_index >= input_limits[spatial_dimension]) {
          continue;
        }

        valid_position_count += 1;
      }
    }
    valid_position_counts.push_back(valid_position_count);
  }

  const int64_t fma_count =
      (input_feature / convolution->feature_group_count()) * output_feature *
      (batch / convolution->batch_group_count()) *
      Product(valid_position_counts);
  return fma_count * kFmaFlops;
}

Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  current_properties_[kFlopsKey] = GetConvolutionFlops(convolution);
  return OkStatus();
}

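// Modeled as an O(n log n) FFT: the flop count is the product of
// floor(log2(dim)) over the FFT dimensions, times the number of elements in
// the (real view of the) input, times kFmaPerComplexMul FMAs per complex
// multiply, times kFmaFlops.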
Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
  auto real_shape =
      fft->operand(0)->shape().IsTuple()
          ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
          : fft->operand(0)->shape();
  constexpr int kFmaPerComplexMul = 4;
  int64_t log_factors = 1;
  for (int64_t dim : fft->fft_length()) {
    log_factors *= Log2Floor<uint64_t>(dim);
  }
  current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
                                   ShapeUtil::ElementsIn(real_shape);
  return OkStatus();
}

Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
  // Half of operand 0 is read.
  float bytes_accessed = GetShapeSize(hlo->shape());
  SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
  bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
  SetOperandBytesAccessed(1, GetShapeSize(hlo->operand(1)->shape()));
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  const Shape& a_shape = hlo->operand(0)->shape();
  const Shape& b_shape = hlo->operand(1)->shape();
  // Estimate as batch * mn^2 / 2 flops.
  int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
  elems *= ShapeUtil::ElementsIn(b_shape);
  current_properties_[kFlopsKey] = kFmaFlops * elems;
  return OkStatus();
}

749 
HandleCholesky(const HloInstruction * hlo)750 Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
751   // Half of operand 0 is read and half of the output will be written.
752   float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
753   SetOutputBytesAccessed(GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
754   bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
755   SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
756   current_properties_[kBytesAccessedKey] = bytes_accessed;
757 
758   const Shape& a_shape = hlo->operand(0)->shape();
759   // Estimate as batch * n^3 / 3 flops.
760   int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
761   elems *= ShapeUtil::ElementsIn(a_shape);
762   current_properties_[kFlopsKey] = elems / 3;
763   return OkStatus();
764 }
765 
HandleOptimizationBarrier(const HloInstruction *)766 Status HloCostAnalysis::HandleOptimizationBarrier(
767     const HloInstruction* /*hlo*/) {
768   return OkStatus();
769 }
770 
HandleAllGather(const HloInstruction *)771 Status HloCostAnalysis::HandleAllGather(const HloInstruction* /*hlo*/) {
772   return OkStatus();
773 }
774 
HandleAllGatherStart(const HloInstruction * hlo)775 Status HloCostAnalysis::HandleAllGatherStart(const HloInstruction* hlo) {
776   return HandleAllGather(hlo);
777 }
778 
HandleAllGatherDone(const HloInstruction *)779 Status HloCostAnalysis::HandleAllGatherDone(const HloInstruction* /*hlo*/) {
780   return OkStatus();
781 }
782 
HandleAllReduce(const HloInstruction * crs)783 Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
784   // We assume 2 replicas, so that each output element is the sum of two input
785   // elements.
786   //
787   // TODO(b/33004697): Compute correct cost here, taking the actual number of
788   // replicas into account.
789   double flops = 0.0;
790   int64_t output_bytes_accessed = 0;
791   ShapeUtil::ForEachSubshape(
792       crs->shape(), [&](const Shape& subshape, const ShapeIndex&) {
793         if (subshape.IsArray()) {
794           flops += ShapeUtil::ElementsIn(subshape);
795           output_bytes_accessed += GetShapeSize(subshape);
796         }
797       });
798   int64_t bytes_accessed = output_bytes_accessed;
799   for (const HloInstruction* operand : crs->operands()) {
800     bytes_accessed += GetShapeSize(operand->shape());
801   }
802   current_properties_[kFlopsKey] = flops;
803   SetOutputBytesAccessed(output_bytes_accessed);
804   current_properties_[kBytesAccessedKey] = bytes_accessed;
805   return OkStatus();
806 }
807 
HandleReduceScatter(const HloInstruction * hlo)808 Status HloCostAnalysis::HandleReduceScatter(const HloInstruction* hlo) {
809   return OkStatus();
810 }
811 
HandleAllReduceStart(const HloInstruction * hlo)812 Status HloCostAnalysis::HandleAllReduceStart(const HloInstruction* hlo) {
813   return HandleAllReduce(hlo);
814 }
815 
HandleAllReduceDone(const HloInstruction *)816 Status HloCostAnalysis::HandleAllReduceDone(const HloInstruction* /*hlo*/) {
817   return OkStatus();
818 }
819 
HandleAllToAll(const HloInstruction * hlo)820 Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
821   return OkStatus();
822 }
823 
HandleCollectivePermute(const HloInstruction *)824 Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
825   return OkStatus();
826 }
827 
HandleCollectivePermuteStart(const HloInstruction *)828 Status HloCostAnalysis::HandleCollectivePermuteStart(
829     const HloInstruction* /*hlo*/) {
830   return OkStatus();
831 }
832 
HandleCollectivePermuteDone(const HloInstruction *)833 Status HloCostAnalysis::HandleCollectivePermuteDone(
834     const HloInstruction* /*hlo*/) {
835   return OkStatus();
836 }
837 
HandlePartitionId(const HloInstruction *)838 Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
839   return OkStatus();
840 }
841 
HandleReplicaId(const HloInstruction *)842 Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) {
843   return OkStatus();
844 }
845 
HandleRng(const HloInstruction * random)846 Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
847   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
848   // cost changes with the implementation and the distribution. For now, assume
849   // the cost of each RNG is same as a transcendental operation.
850   current_properties_[kTranscendentalsKey] =
851       ShapeUtil::ElementsIn(random->shape());
852   return OkStatus();
853 }
854 
HandleRngBitGenerator(const HloInstruction * random)855 Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
856   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
857   // cost changes with the implementation and the distribution. For now, assume
858   // the cost of each RNG is same as a transcendental operation.
859   current_properties_[kTranscendentalsKey] =
860       ShapeUtil::ElementsInRecursive(random->shape());
861   return OkStatus();
862 }
863 
HandleRngGetAndUpdateState(const HloInstruction * random)864 Status HloCostAnalysis::HandleRngGetAndUpdateState(
865     const HloInstruction* random) {
866   return OkStatus();
867 }
868 
HandleFusion(const HloInstruction * fusion)869 Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
870   if (fusion->IsCustomFusion()) {
871     for (const HloInstruction* hlo :
872          fusion->fused_instructions_computation()->instructions()) {
873       if (hlo->opcode() == HloOpcode::kGather) {
874         return HandleGather(hlo);
875       }
876       if (hlo->opcode() == HloOpcode::kScatter) {
877         return HandleScatter(hlo);
878       }
879     }
880   }
881   TF_ASSIGN_OR_RETURN(
882       current_properties_,
883       ProcessSubcomputation(fusion->fused_instructions_computation()));
884 
885   // Fusion nodes that produce a tuple also produce the entries in the tuple.
886   // Ignore the memory accessed inside fused ops, since fusion is supposed to
887   // prevent intermediate data from touching slow memory.
888   current_properties_[kBytesAccessedKey] = 0;
889   ShapeUtil::ForEachSubshape(
890       fusion->shape(),
891       [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
892         if (!subshape.IsArray()) {
893           return;
894         }
895         if (shape_index.empty()) {
896           if (fusion->fused_expression_root()->opcode() ==
897               HloOpcode::kDynamicUpdateSlice) {
898             int64_t size = GetShapeSize(
899                 fusion->fused_expression_root()->operand(1)->shape());
900             current_properties_[kBytesAccessedKey] += size;
901             SetOutputBytesAccessed(shape_index, size);
902             return;
903           }
904         } else if (shape_index.size() == 1) {
905           if (fusion->fused_expression_root()->opcode() == HloOpcode::kTuple &&
906               fusion->fused_expression_root()
907                       ->operand(shape_index[0])
908                       ->opcode() == HloOpcode::kDynamicUpdateSlice) {
909             int64_t size = GetShapeSize(fusion->fused_expression_root()
910                                             ->operand(shape_index[0])
911                                             ->operand(1)
912                                             ->shape());
913             current_properties_[kBytesAccessedKey] += size;
914             SetOutputBytesAccessed(shape_index, size);
915             return;
916           }
917         }
918         current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
919         SetOutputBytesAccessed(shape_index, GetShapeSize(subshape));
920       });
921 
922   if (fusion->shape().IsTuple()) {
923     // Propagate and accumulate the output tuple bytes from the tuple subshapes.
924     // This ensures we have the correct output bytes accessed for the shape
925     // index
926     // {}.
927     std::function<float(const Shape&, const ShapeIndex&)>
928         propagate_output_size_to_parent;
929     propagate_output_size_to_parent = [&](const Shape& shape,
930                                           const ShapeIndex& shape_index) {
931       auto output_bytes_it =
932           current_properties_.find(GetOutputBytesAccessedKey(shape_index));
933       if (output_bytes_it != current_properties_.end()) {
934         return output_bytes_it->second;
935       }
936       float bytes_accessed = 0;
937       for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
938         const Shape& subshape = shape.tuple_shapes(i);
939         ShapeIndex subshape_index(shape_index);
940         subshape_index.push_back(i);
941         bytes_accessed +=
942             propagate_output_size_to_parent(subshape, subshape_index);
943       }
944       SetOutputBytesAccessed(shape_index, bytes_accessed);
945       return bytes_accessed;
946     };
947     auto output_bytes_it =
948         current_properties_.find(GetOutputBytesAccessedKey());
949     if (output_bytes_it != current_properties_.end()) {
950       current_properties_.erase(output_bytes_it);
951     }
952     propagate_output_size_to_parent(fusion->shape(), {});
953   }
954 
955   for (int64_t i = 0; i < fusion->fused_parameters().size(); ++i) {
956     const HloInstruction* operand = fusion->fused_parameter(i);
957     int64_t operand_size = 0;
958     if (!operand->shape().IsTuple()) {
959       operand_size = FusionParameterReadBytes(operand);
960     } else {
961       // If the fusion parameter is a tuple type, find the gte for the leaf
962       // shape and calculate the bytes accessed for those array types.
963       for (const auto& indexed_shape :
964            ShapeUtil::GetLeafShapes(operand->shape())) {
965         const HloInstruction* gte = operand;
966         for (int64_t index : indexed_shape.index) {
967           for (const HloInstruction* user : gte->users()) {
968             if (user->opcode() == HloOpcode::kGetTupleElement &&
969                 user->tuple_index() == index) {
970               gte = user;
971               break;
972             }
973           }
974         }
975         int64_t size = FusionParameterReadBytes(gte);
976         operand_size += size;
977         SetOperandBytesAccessed(i, indexed_shape.index, size);
978       }
979     }
980     current_properties_[kBytesAccessedKey] += operand_size;
981     SetOperandBytesAccessed(i, operand_size);
982   }
983 
984   return OkStatus();
985 }
986 
HandleCall(const HloInstruction * call)987 Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
988   TF_ASSIGN_OR_RETURN(current_properties_,
989                       ProcessSubcomputation(call->to_apply()));
990   current_should_compute_bottleneck_time_ = false;
991   return OkStatus();
992 }
993 
HandleCustomCall(const HloInstruction * custom_call)994 Status HloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
995   // Mark applicable fields as "unknown", since we don't know what this
996   // CustomCall does.  This is better than returning an error, which would stop
997   // iteration, and therefore would prevent us from getting *any* stats for a
998   // computation which contains a CustomCall.
999   current_properties_[kOptimalSecondsKey] = -1;
1000   current_properties_[kBytesAccessedKey] = -1;
1001   SetOutputBytesAccessed(-1);
1002   for (int i = 0; i < custom_call->operand_count(); ++i) {
1003     SetOperandBytesAccessed(i, -1);
1004   }
1005   current_properties_[kFlopsKey] = -1;
1006   current_should_compute_bottleneck_time_ = false;
1007   return OkStatus();
1008 }
1009 
HandleSort(const HloInstruction * sort)1010 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
1011   // This assumes a comparison based N*log(N) algorithm. As for all ops, the
1012   // actual properties of the op depend on the backend implementation.
1013   int64_t elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
1014   current_properties_[kFlopsKey] = elements * Log2Ceiling<uint64_t>(elements);
1015   return OkStatus();
1016 }
1017 
HandleWhile(const HloInstruction * xla_while)1018 Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
1019   // Since the number of iterations of the while node will not always be
1020   // something that we can statically analyze, we cannot precisely compute the
1021   // cost of a while node. For now compute the cost of a single iteration.
1022   TF_ASSIGN_OR_RETURN(const Properties body_properties,
1023                       ProcessSubcomputation(xla_while->while_body()));
1024 
1025   TF_ASSIGN_OR_RETURN(const Properties condition_properties,
1026                       ProcessSubcomputation(xla_while->while_condition()));
1027 
1028   current_properties_.clear();
1029   for (const auto& property : body_properties) {
1030     current_properties_[property.first] += property.second;
1031   }
1032   for (const auto& property : condition_properties) {
1033     current_properties_[property.first] += property.second;
1034   }
1035   current_should_compute_bottleneck_time_ = false;
1036 
1037   return OkStatus();
1038 }
1039 
HandleConditional(const HloInstruction * conditional)1040 Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
1041   // Compute the cost of the branch computations and take the maximum from those
1042   // for each property.
1043   TF_ASSIGN_OR_RETURN(
1044       const Properties branch0_computation_properties,
1045       ProcessSubcomputation(conditional->branch_computation(0)));
1046   current_properties_ = branch0_computation_properties;
1047   for (int j = 1; j < conditional->branch_count(); ++j) {
1048     TF_ASSIGN_OR_RETURN(
1049         const Properties branch_computation_properties,
1050         ProcessSubcomputation(conditional->branch_computation(j)));
1051     for (const auto& property : branch_computation_properties) {
1052       if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_,
1053                                                property)) {
1054         auto& current_property = current_properties_[property.first];
1055         current_property = std::max(current_property, property.second);
1056       }
1057     }
1058   }
1059   current_should_compute_bottleneck_time_ = false;
1060 
1061   return OkStatus();
1062 }
1063 
Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
  // Gather doesn't read the whole input buffer; it's equivalent to a copy of
  // the output shape's bytes plus a read of the gather indices.
  int64_t output_size = GetShapeSize(gather->shape());
  current_properties_[kBytesAccessedKey] =
      output_size * 2 + GetShapeSize(gather->operand(1)->shape());
  SetOperandBytesAccessed(0, output_size);
  SetOperandBytesAccessed(1, GetShapeSize(gather->operand(1)->shape()));
  SetOutputBytesAccessed(output_size);
  // Gather does not issue any flops.
  return OkStatus();
}

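// Scatter operands are numbered as: the N inputs being scattered into
// (0..N-1), the scatter indices (N), and the N update tensors (N+1..2N).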
Status HloCostAnalysis::HandleScatter(const HloInstruction* hlo) {
  auto* scatter = Cast<HloScatterInstruction>(hlo);
  // Scatter accesses the equivalent of 3N update shapes (input, output, and
  // updates), and the scatter indices.
  int64_t total_update_size = 0;
  for (int i = 0, n = scatter->scatter_operand_count(); i < n; ++i) {
    int64_t update_size = GetShapeSize(scatter->scatter_updates()[i]->shape());
    SetOperandBytesAccessed(i, update_size);
    SetOperandBytesAccessed(n + 1 + i, update_size);
    total_update_size += update_size;
  }
  int64_t scatter_indices_size =
      GetShapeSize(scatter->scatter_indices()->shape());
  SetOperandBytesAccessed(scatter->scatter_operand_count(),
                          scatter_indices_size);
  current_properties_[kBytesAccessedKey] =
      total_update_size * 3 + scatter_indices_size;
  SetOutputBytesAccessed(total_update_size);
  const int64_t element_count =
      ShapeUtil::ElementsIn(scatter->scatter_updates()[0]->shape());
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(scatter->to_apply()));
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return OkStatus();
}

Status HloCostAnalysis::HandleGetDimensionSize(
    const HloInstruction* /*get_size*/) {
  return OkStatus();
}

Status HloCostAnalysis::HandleSetDimensionSize(
    const HloInstruction* /*set_size*/) {
  return OkStatus();
}

Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
  return OkStatus();
}

float HloCostAnalysis::flop_count() const {
  return GetProperty(kFlopsKey, properties_sum_);
}

float HloCostAnalysis::transcendental_count() const {
  return GetProperty(kTranscendentalsKey, properties_sum_);
}

float HloCostAnalysis::bytes_accessed() const {
  return GetProperty(kBytesAccessedKey, properties_sum_);
}

float HloCostAnalysis::optimal_seconds() const {
  return GetProperty(kOptimalSecondsKey, properties_sum_);
}

int64_t HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
}

int64_t HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
}

int64_t HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}

int64_t HloCostAnalysis::operand_bytes_accessed(const HloInstruction& hlo,
                                                int64_t operand_num,
                                                ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOperandBytesAccessedKey(operand_num, index),
                           hlo_properties_);
}

int64_t HloCostAnalysis::output_bytes_accessed(const HloInstruction& hlo,
                                               ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOutputBytesAccessedKey(index),
                           hlo_properties_);
}

float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}

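// Sums the recorded per-operand bytes accessed over every leaf shape of every
// operand, optionally restricted to leaves whose layout places them in
// `memory_space`.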
int64_t HloCostAnalysis::GetBytesRead(
    const HloInstruction& hlo, std::optional<int64_t> memory_space) const {
  int64_t bytes_read = 0;
  for (int operand_number = 0; operand_number < hlo.operand_count();
       ++operand_number) {
    const Shape& shape = hlo.operand(operand_number)->shape();
    ShapeUtil::ForEachSubshape(
        shape, [&](const Shape& sub_shape, const ShapeIndex& index) {
          if (ShapeUtil::IsLeafIndex(shape, index)) {
            std::optional<int64_t> index_memory_space;
            if (sub_shape.has_layout()) {
              index_memory_space = sub_shape.layout().memory_space();
            }
            if (!memory_space || memory_space == index_memory_space) {
              bytes_read += operand_bytes_accessed(hlo, operand_number, index);
            }
          }
        });
  }
  return bytes_read;
}

int64_t HloCostAnalysis::GetBytesWritten(
    const HloInstruction& hlo, std::optional<int64_t> memory_space) const {
  int64_t bytes_written = 0;
  for (const ShapeUtil::IndexedShape& indexed_shape :
       ShapeUtil::GetLeafShapes(hlo.shape())) {
    std::optional<int64_t> index_memory_space;
    if (indexed_shape.shape.has_layout()) {
      index_memory_space = indexed_shape.shape.layout().memory_space();
    }
    if (!memory_space || memory_space == index_memory_space) {
      bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
    }
  }
  return bytes_written;
}

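// Runs a nested cost analysis over `computation`, merges its per-instruction
// results into this analysis, and returns the nested analysis' accumulated
// properties.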
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
    HloComputation* computation) {
  auto visitor = CreateNestedCostAnalysis();
  visitor->ReserveVisitStates(computation->instruction_count());
  TF_RETURN_IF_ERROR(computation->Accept(visitor.get()));
  hlo_properties_.insert(visitor->hlo_properties_.begin(),
                         visitor->hlo_properties_.end());
  return visitor->properties();
}

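// Creates the analysis used to visit nested computations, configured with the
// same options as this one.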
std::unique_ptr<HloCostAnalysis> HloCostAnalysis::CreateNestedCostAnalysis() {
  return std::make_unique<HloCostAnalysis>(options_);
}

void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
                                              float value) {
  current_properties_[GetOperandBytesAccessedKey(operand_num).c_str()] = value;
}

void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
                                              ShapeIndex index, float value) {
  current_properties_[GetOperandBytesAccessedKey(operand_num, index).c_str()] =
      value;
}

void HloCostAnalysis::SetOutputBytesAccessed(float value) {
  current_properties_[GetOutputBytesAccessedKey()] = value;
}

void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
  current_properties_[GetOutputBytesAccessedKey(index)] = value;
}

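// Per-operand and per-output bytes-accessed values are stored under string
// keys derived from kBytesAccessedKey plus the operand number and shape index.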
/*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey(
    int64_t operand_num, ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ",
                      index.ToString());
}

/*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey(
    ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " output ", index.ToString());
}

}  // namespace xla