1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_stats.h"
17
18 #include <string>
19
20 #include "absl/strings/match.h"
21 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/status.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/statusor.h"
28
29 namespace xla {
30 namespace gpu {
31
32 namespace {
33
34 class OpcodeCollector : public ConstDfsHloVisitorWithDefault {
35 public:
GetUniqueOpcodes()36 std::set<std::string> GetUniqueOpcodes() { return opcodes_; }
37
38 protected:
DefaultAction(const xla::HloInstruction * instr)39 Status DefaultAction(const xla::HloInstruction* instr) final {
40 switch (instr->opcode()) {
41 case HloOpcode::kConstant:
42 break;
43 case HloOpcode::kParameter:
44 break;
45 // Unary
46 case HloOpcode::kAbs:
47 case HloOpcode::kCbrt:
48 case HloOpcode::kCeil:
49 case HloOpcode::kCos:
50 case HloOpcode::kExp:
51 case HloOpcode::kExpm1:
52 case HloOpcode::kFloor:
53 case HloOpcode::kLog:
54 case HloOpcode::kLog1p:
55 case HloOpcode::kLogistic:
56 case HloOpcode::kNegate:
57 case HloOpcode::kRoundNearestAfz:
58 case HloOpcode::kRoundNearestEven:
59 case HloOpcode::kRsqrt:
60 case HloOpcode::kSign:
61 case HloOpcode::kSin:
62 case HloOpcode::kSqrt:
63 case HloOpcode::kTanh:
64 // Binary
65 case HloOpcode::kAdd:
66 case HloOpcode::kAtan2:
67 case HloOpcode::kDivide:
68 case HloOpcode::kMultiply:
69 case HloOpcode::kSubtract:
70 opcodes_.insert("cwise");
71 break;
72 default:
73 opcodes_.insert(HloOpcodeString(instr->opcode()));
74 }
75 return Status::OK();
76 }
77
78 private:
79 std::set<std::string> opcodes_;
80 };
81
GetUniqueOpcodes(HloComputation * computation)82 std::set<std::string> GetUniqueOpcodes(HloComputation* computation) {
83 OpcodeCollector collector;
84 if (computation->Accept(&collector) != Status::OK()) {
85 return {};
86 }
87 return collector.GetUniqueOpcodes();
88 }
89
90 } // namespace
91
ToString()92 std::string HloOpcodeHistogram::ToString() {
93 std::string result;
94 for (const auto& entry : *this) {
95 absl::StrAppend(&result, "{", absl::StrJoin(entry.first, ", "),
96 "}: ", entry.second, "\n");
97 }
98 return result;
99 }
100
RunOnModule(HloModule * module)101 Status HloFusionStatsVisitor::RunOnModule(HloModule* module) {
102 TF_RETURN_IF_ERROR(module->entry_computation()->Accept(this));
103 return Status::OK();
104 }
105
ToString()106 std::string HloFusionStatsVisitor::ToString() {
107 return absl::StrCat("HLO Fusion Stats:\n",
108 "Number of fusion ops: ", num_fusions_, "\n",
109 "Number of kLoop fusions: ", num_loop_fusions_, "\n",
110 loop_fusion_opcode_histogram_.ToString(), "\n",
111 "Number of kInput fusions: ", num_input_fusions_, "\n",
112 input_fusion_opcode_histogram_.ToString());
113 }
114
DefaultAction(const xla::HloInstruction * instr)115 Status HloFusionStatsVisitor::DefaultAction(const xla::HloInstruction* instr) {
116 return Status::OK();
117 }
118
HandleFusion(const HloInstruction * fusion)119 Status HloFusionStatsVisitor::HandleFusion(const HloInstruction* fusion) {
120 num_fusions_++;
121 std::set<std::string> opcodes =
122 GetUniqueOpcodes(fusion->fused_instructions_computation());
123 if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
124 num_loop_fusions_++;
125 loop_fusion_opcode_histogram_[opcodes]++;
126 } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kInput) {
127 num_input_fusions_++;
128 input_fusion_opcode_histogram_[opcodes]++;
129 }
130 return Status::OK();
131 }
132
133 } // namespace gpu
134 } // namespace xla
135