xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/hlo_proto_map.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/utils/hlo_proto_map.h"
17 
18 #include <cstdint>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/status/status.h"
26 #include "absl/status/statusor.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/string_view.h"
29 #include "tensorflow/compiler/xla/service/hlo.pb.h"
30 #include "tensorflow/core/profiler/convert/xla_op_utils.h"
31 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
32 #include "tensorflow/core/profiler/utils/xplane_schema.h"
33 #include "tensorflow/core/profiler/utils/xplane_utils.h"
34 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
35 
36 namespace tensorflow {
37 namespace profiler {
38 namespace {
39 
NumHeapSimulatorTraceEvents(const xla::HloProto * hlo)40 int NumHeapSimulatorTraceEvents(const xla::HloProto* hlo) {
41   int result = 0;
42   for (const auto& trace : hlo->buffer_assignment().heap_simulator_traces()) {
43     result += trace.events_size();
44   }
45   return result;
46 }
47 
48 }  // namespace
49 
50 std::vector<std::pair<uint64_t, std::unique_ptr<xla::HloProto>>>
ParseHloProtosFromXSpace(const XSpace & space)51 ParseHloProtosFromXSpace(const XSpace& space) {
52   std::vector<std::pair<uint64_t, std::unique_ptr<xla::HloProto>>> hlo_protos;
53   const XPlane* raw_plane = FindPlaneWithName(space, kMetadataPlaneName);
54   if (raw_plane != nullptr) {
55     XPlaneVisitor plane = CreateTfXPlaneVisitor(raw_plane);
56     if (raw_plane->stats_size() > 0) {
57       // Fallback for legacy aggregated XPlane.
58       // TODO(b/235990417): Remove after 06/14/2023.
59       plane.ForEachStat([&](const XStatVisitor& stat) {
60         if (stat.ValueCase() != XStat::kBytesValue) return;
61         auto hlo_proto = std::make_unique<xla::HloProto>();
62         absl::string_view byte_value = stat.BytesValue();
63         if (hlo_proto->ParseFromArray(byte_value.data(), byte_value.size())) {
64           hlo_protos.emplace_back(stat.Id(), std::move(hlo_proto));
65         }
66       });
67     } else {
68       const XStatMetadata* hlo_proto_stat_metadata =
69           plane.GetStatMetadataByType(StatType::kHloProto);
70       if (hlo_proto_stat_metadata == nullptr) {
71         // Fallback for legacy XPlane.
72         // TODO(b/235990417): Remove after 06/14/2023.
73         hlo_proto_stat_metadata = plane.GetStatMetadata(StatType::kHloProto);
74       }
75       if (hlo_proto_stat_metadata != nullptr) {
76         plane.ForEachEventMetadata(
77             [&](const XEventMetadataVisitor& event_metadata) {
78               auto hlo_proto_stat = event_metadata.GetStat(
79                   StatType::kHloProto, *hlo_proto_stat_metadata);
80               if (!hlo_proto_stat) return;
81               if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return;
82               auto hlo_proto = std::make_unique<xla::HloProto>();
83               absl::string_view byte_value = hlo_proto_stat->BytesValue();
84               if (hlo_proto->ParseFromArray(byte_value.data(),
85                                             byte_value.size())) {
86                 hlo_protos.emplace_back(event_metadata.Id(),
87                                         std::move(hlo_proto));
88               }
89             });
90       }
91     }
92   }
93   return hlo_protos;
94 }
95 
AddHloProto(uint64_t program_id,const xla::HloProto * hlo_proto)96 bool HloProtoMap::AddHloProto(uint64_t program_id,
97                               const xla::HloProto* hlo_proto) {
98   bool new_program_id =
99       hlo_protos_by_program_id_.try_emplace(program_id, hlo_proto).second;
100   absl::string_view hlo_module_name = hlo_proto->hlo_module().name();
101   bool new_module_name =
102       hlo_protos_by_name_
103           .try_emplace(HloModuleNameWithProgramId(hlo_module_name, program_id),
104                        hlo_proto)
105           .second;
106   return new_program_id || new_module_name;
107 }
108 
AddHloProto(uint64_t program_id,std::unique_ptr<const xla::HloProto> hlo_proto)109 void HloProtoMap::AddHloProto(uint64_t program_id,
110                               std::unique_ptr<const xla::HloProto> hlo_proto) {
111   if (AddHloProto(program_id, hlo_proto.get())) {
112     // Only add to <owned_hlo_protos_> if <hlo_proto> is new to HloProtoMap.
113     owned_hlo_protos_.push_back(std::move(hlo_proto));
114   }
115 }
116 
AddHloProtosFromXSpace(const XSpace & space)117 void HloProtoMap::AddHloProtosFromXSpace(const XSpace& space) {
118   for (auto& [program_id, hlo_proto] : ParseHloProtosFromXSpace(space)) {
119     AddHloProto(program_id, std::move(hlo_proto));
120   }
121 }
122 
GetModuleList() const123 std::vector<absl::string_view> HloProtoMap::GetModuleList() const {
124   std::vector<absl::string_view> module_list;
125   module_list.reserve(hlo_protos_by_name_.size());
126   for (const auto& [name, hlo_proto] : hlo_protos_by_name_) {
127     module_list.push_back(name);
128   }
129   return module_list;
130 }
131 
GetSortedModuleList() const132 std::vector<absl::string_view> HloProtoMap::GetSortedModuleList() const {
133   std::vector<absl::string_view> module_list = GetModuleList();
134   absl::c_sort(module_list);
135   return module_list;
136 }
137 
GetSortedModuleListByHeapTraceSize() const138 std::vector<absl::string_view> HloProtoMap::GetSortedModuleListByHeapTraceSize()
139     const {
140   std::vector<std::pair<absl::string_view, const xla::HloProto*>> hlo_protos(
141       hlo_protos_by_name_.begin(), hlo_protos_by_name_.end());
142 
143   // Sort the hlo protos by heap trace size and then by hlo module name.
144   // This way trivial computations will be on the bottom of the list.
145   absl::c_stable_sort(hlo_protos, [](const auto& a, const auto& b) {
146     int num_a = NumHeapSimulatorTraceEvents(a.second);
147     int num_b = NumHeapSimulatorTraceEvents(b.second);
148     return std::tie(num_a, b.first) > std::tie(num_b, a.first);
149   });
150 
151   std::vector<absl::string_view> module_list;
152   module_list.reserve(hlo_protos.size());
153   for (const auto& [name, hlo_proto] : hlo_protos) {
154     module_list.push_back(name);
155   }
156   return module_list;
157 }
158 
GetHloProtoByModuleName(absl::string_view module_name) const159 absl::StatusOr<const xla::HloProto*> HloProtoMap::GetHloProtoByModuleName(
160     absl::string_view module_name) const {
161   auto iter = hlo_protos_by_name_.find(module_name);
162   if (iter != hlo_protos_by_name_.end()) {
163     return iter->second;
164   }
165   return absl::NotFoundError(
166       absl::StrCat("Module name: ", module_name, " is not found."));
167 }
168 
169 }  // namespace profiler
170 }  // namespace tensorflow
171