xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/hlo_fusion_stats.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_stats.h"

#include <set>
#include <string>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 namespace {
33 
34 class OpcodeCollector : public ConstDfsHloVisitorWithDefault {
35  public:
GetUniqueOpcodes()36   std::set<std::string> GetUniqueOpcodes() { return opcodes_; }
37 
38  protected:
DefaultAction(const xla::HloInstruction * instr)39   Status DefaultAction(const xla::HloInstruction* instr) final {
40     switch (instr->opcode()) {
41       case HloOpcode::kConstant:
42         break;
43       case HloOpcode::kParameter:
44         break;
45       // Unary
46       case HloOpcode::kAbs:
47       case HloOpcode::kCbrt:
48       case HloOpcode::kCeil:
49       case HloOpcode::kCos:
50       case HloOpcode::kExp:
51       case HloOpcode::kExpm1:
52       case HloOpcode::kFloor:
53       case HloOpcode::kLog:
54       case HloOpcode::kLog1p:
55       case HloOpcode::kLogistic:
56       case HloOpcode::kNegate:
57       case HloOpcode::kRoundNearestAfz:
58       case HloOpcode::kRoundNearestEven:
59       case HloOpcode::kRsqrt:
60       case HloOpcode::kSign:
61       case HloOpcode::kSin:
62       case HloOpcode::kSqrt:
63       case HloOpcode::kTanh:
64       // Binary
65       case HloOpcode::kAdd:
66       case HloOpcode::kAtan2:
67       case HloOpcode::kDivide:
68       case HloOpcode::kMultiply:
69       case HloOpcode::kSubtract:
70         opcodes_.insert("cwise");
71         break;
72       default:
73         opcodes_.insert(HloOpcodeString(instr->opcode()));
74     }
75     return Status::OK();
76   }
77 
78  private:
79   std::set<std::string> opcodes_;
80 };
81 
GetUniqueOpcodes(HloComputation * computation)82 std::set<std::string> GetUniqueOpcodes(HloComputation* computation) {
83   OpcodeCollector collector;
84   if (computation->Accept(&collector) != Status::OK()) {
85     return {};
86   }
87   return collector.GetUniqueOpcodes();
88 }
89 
90 }  // namespace
91 
ToString()92 std::string HloOpcodeHistogram::ToString() {
93   std::string result;
94   for (const auto& entry : *this) {
95     absl::StrAppend(&result, "{", absl::StrJoin(entry.first, ", "),
96                     "}: ", entry.second, "\n");
97   }
98   return result;
99 }
100 
RunOnModule(HloModule * module)101 Status HloFusionStatsVisitor::RunOnModule(HloModule* module) {
102   TF_RETURN_IF_ERROR(module->entry_computation()->Accept(this));
103   return Status::OK();
104 }
105 
ToString()106 std::string HloFusionStatsVisitor::ToString() {
107   return absl::StrCat("HLO Fusion Stats:\n",
108                       "Number of fusion ops: ", num_fusions_, "\n",
109                       "Number of kLoop fusions: ", num_loop_fusions_, "\n",
110                       loop_fusion_opcode_histogram_.ToString(), "\n",
111                       "Number of kInput fusions: ", num_input_fusions_, "\n",
112                       input_fusion_opcode_histogram_.ToString());
113 }
114 
DefaultAction(const xla::HloInstruction * instr)115 Status HloFusionStatsVisitor::DefaultAction(const xla::HloInstruction* instr) {
116   return Status::OK();
117 }
118 
HandleFusion(const HloInstruction * fusion)119 Status HloFusionStatsVisitor::HandleFusion(const HloInstruction* fusion) {
120   num_fusions_++;
121   std::set<std::string> opcodes =
122       GetUniqueOpcodes(fusion->fused_instructions_computation());
123   if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
124     num_loop_fusions_++;
125     loop_fusion_opcode_histogram_[opcodes]++;
126   } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kInput) {
127     num_input_fusions_++;
128     input_fusion_opcode_histogram_[opcodes]++;
129   }
130   return Status::OK();
131 }
132 
133 }  // namespace gpu
134 }  // namespace xla
135