1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/utils/hlo_proto_map.h"
17
18 #include <cstdint>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/status/status.h"
26 #include "absl/status/statusor.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/string_view.h"
29 #include "tensorflow/compiler/xla/service/hlo.pb.h"
30 #include "tensorflow/core/profiler/convert/xla_op_utils.h"
31 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
32 #include "tensorflow/core/profiler/utils/xplane_schema.h"
33 #include "tensorflow/core/profiler/utils/xplane_utils.h"
34 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
35
36 namespace tensorflow {
37 namespace profiler {
38 namespace {
39
NumHeapSimulatorTraceEvents(const xla::HloProto * hlo)40 int NumHeapSimulatorTraceEvents(const xla::HloProto* hlo) {
41 int result = 0;
42 for (const auto& trace : hlo->buffer_assignment().heap_simulator_traces()) {
43 result += trace.events_size();
44 }
45 return result;
46 }
47
48 } // namespace
49
50 std::vector<std::pair<uint64_t, std::unique_ptr<xla::HloProto>>>
ParseHloProtosFromXSpace(const XSpace & space)51 ParseHloProtosFromXSpace(const XSpace& space) {
52 std::vector<std::pair<uint64_t, std::unique_ptr<xla::HloProto>>> hlo_protos;
53 const XPlane* raw_plane = FindPlaneWithName(space, kMetadataPlaneName);
54 if (raw_plane != nullptr) {
55 XPlaneVisitor plane = CreateTfXPlaneVisitor(raw_plane);
56 if (raw_plane->stats_size() > 0) {
57 // Fallback for legacy aggregated XPlane.
58 // TODO(b/235990417): Remove after 06/14/2023.
59 plane.ForEachStat([&](const XStatVisitor& stat) {
60 if (stat.ValueCase() != XStat::kBytesValue) return;
61 auto hlo_proto = std::make_unique<xla::HloProto>();
62 absl::string_view byte_value = stat.BytesValue();
63 if (hlo_proto->ParseFromArray(byte_value.data(), byte_value.size())) {
64 hlo_protos.emplace_back(stat.Id(), std::move(hlo_proto));
65 }
66 });
67 } else {
68 const XStatMetadata* hlo_proto_stat_metadata =
69 plane.GetStatMetadataByType(StatType::kHloProto);
70 if (hlo_proto_stat_metadata == nullptr) {
71 // Fallback for legacy XPlane.
72 // TODO(b/235990417): Remove after 06/14/2023.
73 hlo_proto_stat_metadata = plane.GetStatMetadata(StatType::kHloProto);
74 }
75 if (hlo_proto_stat_metadata != nullptr) {
76 plane.ForEachEventMetadata(
77 [&](const XEventMetadataVisitor& event_metadata) {
78 auto hlo_proto_stat = event_metadata.GetStat(
79 StatType::kHloProto, *hlo_proto_stat_metadata);
80 if (!hlo_proto_stat) return;
81 if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return;
82 auto hlo_proto = std::make_unique<xla::HloProto>();
83 absl::string_view byte_value = hlo_proto_stat->BytesValue();
84 if (hlo_proto->ParseFromArray(byte_value.data(),
85 byte_value.size())) {
86 hlo_protos.emplace_back(event_metadata.Id(),
87 std::move(hlo_proto));
88 }
89 });
90 }
91 }
92 }
93 return hlo_protos;
94 }
95
AddHloProto(uint64_t program_id,const xla::HloProto * hlo_proto)96 bool HloProtoMap::AddHloProto(uint64_t program_id,
97 const xla::HloProto* hlo_proto) {
98 bool new_program_id =
99 hlo_protos_by_program_id_.try_emplace(program_id, hlo_proto).second;
100 absl::string_view hlo_module_name = hlo_proto->hlo_module().name();
101 bool new_module_name =
102 hlo_protos_by_name_
103 .try_emplace(HloModuleNameWithProgramId(hlo_module_name, program_id),
104 hlo_proto)
105 .second;
106 return new_program_id || new_module_name;
107 }
108
AddHloProto(uint64_t program_id,std::unique_ptr<const xla::HloProto> hlo_proto)109 void HloProtoMap::AddHloProto(uint64_t program_id,
110 std::unique_ptr<const xla::HloProto> hlo_proto) {
111 if (AddHloProto(program_id, hlo_proto.get())) {
112 // Only add to <owned_hlo_protos_> if <hlo_proto> is new to HloProtoMap.
113 owned_hlo_protos_.push_back(std::move(hlo_proto));
114 }
115 }
116
AddHloProtosFromXSpace(const XSpace & space)117 void HloProtoMap::AddHloProtosFromXSpace(const XSpace& space) {
118 for (auto& [program_id, hlo_proto] : ParseHloProtosFromXSpace(space)) {
119 AddHloProto(program_id, std::move(hlo_proto));
120 }
121 }
122
GetModuleList() const123 std::vector<absl::string_view> HloProtoMap::GetModuleList() const {
124 std::vector<absl::string_view> module_list;
125 module_list.reserve(hlo_protos_by_name_.size());
126 for (const auto& [name, hlo_proto] : hlo_protos_by_name_) {
127 module_list.push_back(name);
128 }
129 return module_list;
130 }
131
GetSortedModuleList() const132 std::vector<absl::string_view> HloProtoMap::GetSortedModuleList() const {
133 std::vector<absl::string_view> module_list = GetModuleList();
134 absl::c_sort(module_list);
135 return module_list;
136 }
137
GetSortedModuleListByHeapTraceSize() const138 std::vector<absl::string_view> HloProtoMap::GetSortedModuleListByHeapTraceSize()
139 const {
140 std::vector<std::pair<absl::string_view, const xla::HloProto*>> hlo_protos(
141 hlo_protos_by_name_.begin(), hlo_protos_by_name_.end());
142
143 // Sort the hlo protos by heap trace size and then by hlo module name.
144 // This way trivial computations will be on the bottom of the list.
145 absl::c_stable_sort(hlo_protos, [](const auto& a, const auto& b) {
146 int num_a = NumHeapSimulatorTraceEvents(a.second);
147 int num_b = NumHeapSimulatorTraceEvents(b.second);
148 return std::tie(num_a, b.first) > std::tie(num_b, a.first);
149 });
150
151 std::vector<absl::string_view> module_list;
152 module_list.reserve(hlo_protos.size());
153 for (const auto& [name, hlo_proto] : hlo_protos) {
154 module_list.push_back(name);
155 }
156 return module_list;
157 }
158
GetHloProtoByModuleName(absl::string_view module_name) const159 absl::StatusOr<const xla::HloProto*> HloProtoMap::GetHloProtoByModuleName(
160 absl::string_view module_name) const {
161 auto iter = hlo_protos_by_name_.find(module_name);
162 if (iter != hlo_protos_by_name_.end()) {
163 return iter->second;
164 }
165 return absl::NotFoundError(
166 absl::StrCat("Module name: ", module_name, " is not found."));
167 }
168
169 } // namespace profiler
170 } // namespace tensorflow
171