1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/trace_processor/perfetto_sql/intrinsics/table_functions/dfs_weight_bounded.h"
18
19 #include <algorithm>
20 #include <cstddef>
21 #include <cstdint>
22 #include <memory>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "perfetto/base/logging.h"
29 #include "perfetto/base/status.h"
30 #include "perfetto/ext/base/status_or.h"
31 #include "perfetto/protozero/proto_decoder.h"
32 #include "perfetto/trace_processor/basic_types.h"
33 #include "protos/perfetto/trace_processor/metrics_impl.pbzero.h"
34 #include "src/trace_processor/containers/string_pool.h"
35 #include "src/trace_processor/db/column.h"
36 #include "src/trace_processor/db/table.h"
37 #include "src/trace_processor/perfetto_sql/intrinsics/table_functions/tables_py.h"
38 #include "src/trace_processor/util/status_macros.h"
39
40 namespace perfetto::trace_processor {
namespace tables {
// Out-of-line definition of the (defaulted) destructor of the result table
// type. The class itself is declared elsewhere - presumably in the included
// tables_py.h; confirm against that header.
DfsWeightBoundedTable::~DfsWeightBoundedTable() = default;
}  // namespace tables
44
45 namespace {
// A (node id, weight) pair. In this file it is used in two roles:
//  * as an outgoing graph edge, where |id| is the destination node and
//    |weight| is the weight of the edge;
//  * as a DFS root, where |id| is the root node and |weight| is the target
//    weight that bounds the search started from that root.
struct Edge {
  uint32_t id;
  uint32_t weight;
};
// The outgoing edges of a single source node.
using Destinations = std::vector<Edge>;
51
ParseSourceToDestionationsMap(protos::pbzero::RepeatedBuilderResult::Decoder & source,protos::pbzero::RepeatedBuilderResult::Decoder & dest,protos::pbzero::RepeatedBuilderResult::Decoder & weight)52 base::StatusOr<std::vector<Destinations>> ParseSourceToDestionationsMap(
53 protos::pbzero::RepeatedBuilderResult::Decoder& source,
54 protos::pbzero::RepeatedBuilderResult::Decoder& dest,
55 protos::pbzero::RepeatedBuilderResult::Decoder& weight) {
56 std::vector<Destinations> source_to_destinations_map;
57 bool parse_error = false;
58 auto source_node_ids = source.int_values(&parse_error);
59 auto dest_node_ids = dest.int_values(&parse_error);
60 auto edge_weights = weight.int_values(&parse_error);
61
62 for (; source_node_ids && dest_node_ids && edge_weights;
63 ++source_node_ids, ++dest_node_ids, ++edge_weights) {
64 source_to_destinations_map.resize(
65 std::max(source_to_destinations_map.size(),
66 std::max(static_cast<size_t>(*source_node_ids + 1),
67 static_cast<size_t>(*dest_node_ids + 1))));
68 source_to_destinations_map[static_cast<uint32_t>(*source_node_ids)]
69 .push_back(Edge{static_cast<uint32_t>(*dest_node_ids),
70 static_cast<uint32_t>(*edge_weights)});
71 }
72 if (parse_error) {
73 return base::ErrStatus("Failed while parsing source or dest ids");
74 }
75 if (static_cast<bool>(source_node_ids) != static_cast<bool>(dest_node_ids)) {
76 return base::ErrStatus(
77 "dfs_weight_bounded: length of source and destination columns is not "
78 "the same");
79 }
80 return source_to_destinations_map;
81 }
82
ParseRootToMaxWeightMap(protos::pbzero::RepeatedBuilderResult::Decoder & start,protos::pbzero::RepeatedBuilderResult::Decoder & end)83 base::StatusOr<std::vector<Edge>> ParseRootToMaxWeightMap(
84 protos::pbzero::RepeatedBuilderResult::Decoder& start,
85 protos::pbzero::RepeatedBuilderResult::Decoder& end) {
86 std::vector<Edge> roots;
87 bool parse_error = false;
88 auto root_node_ids = start.int_values(&parse_error);
89 auto target_weights = end.int_values(&parse_error);
90
91 for (; root_node_ids && target_weights; ++root_node_ids, ++target_weights) {
92 roots.push_back(Edge{static_cast<uint32_t>(*root_node_ids),
93 static_cast<uint32_t>(*target_weights)});
94 }
95
96 if (parse_error) {
97 return base::ErrStatus(
98 "Failed while parsing root_node_ids or root_target_weights");
99 }
100 if (static_cast<bool>(root_node_ids) != static_cast<bool>(target_weights)) {
101 return base::ErrStatus(
102 "dfs_weight_bounded: length of root_node_ids and root_target_weights "
103 "columns is not the same");
104 }
105 return roots;
106 }
107
DfsWeightBoundedImpl(tables::DfsWeightBoundedTable * table,const std::vector<Destinations> & source_to_destinations_map,const std::vector<Edge> & roots,const bool is_target_weight_floor)108 void DfsWeightBoundedImpl(
109 tables::DfsWeightBoundedTable* table,
110 const std::vector<Destinations>& source_to_destinations_map,
111 const std::vector<Edge>& roots,
112 const bool is_target_weight_floor) {
113 struct StackState {
114 uint32_t id;
115 uint32_t weight;
116 std::optional<uint32_t> parent_id;
117 };
118
119 std::vector<uint8_t> seen_node_ids(source_to_destinations_map.size());
120 std::vector<StackState> stack;
121
122 for (const auto& root : roots) {
123 stack.clear();
124 stack.push_back({root.id, 0, std::nullopt});
125 std::fill(seen_node_ids.begin(), seen_node_ids.end(), 0);
126
127 for (uint32_t total_weight = 0; !stack.empty();) {
128 StackState stack_state = stack.back();
129 stack.pop_back();
130
131 if (seen_node_ids[stack_state.id]) {
132 continue;
133 }
134 seen_node_ids[stack_state.id] = true;
135 total_weight += stack_state.weight;
136
137 if (!is_target_weight_floor && total_weight > root.weight) {
138 // If target weight is a ceiling weight then we don't want to include
139 // the last node that crosses the threshold.
140 break;
141 }
142
143 tables::DfsWeightBoundedTable::Row row;
144 row.root_node_id = root.id;
145 row.node_id = stack_state.id;
146 row.parent_node_id = stack_state.parent_id;
147 table->Insert(row);
148
149 if (total_weight > root.weight) {
150 // If the target weight is a floor weight, we add the last node that
151 // crossed the threshold before exiting the search.
152 break;
153 }
154
155 PERFETTO_DCHECK(stack_state.id < source_to_destinations_map.size());
156
157 const auto& children = source_to_destinations_map[stack_state.id];
158 for (auto it = children.rbegin(); it != children.rend(); ++it) {
159 stack.emplace_back(StackState{(*it).id, (*it).weight, stack_state.id});
160 }
161 }
162 }
163 }
164 } // namespace
165
// Stores |pool| (not owned by anything visible in this file); it is passed to
// every DfsWeightBoundedTable that ComputeTable creates.
DfsWeightBounded::DfsWeightBounded(StringPool* pool) : pool_(pool) {}
// Defaulted out-of-line destructor.
DfsWeightBounded::~DfsWeightBounded() = default;
168
// Returns the static schema of DfsWeightBoundedTable, i.e. the schema of the
// tables produced by ComputeTable.
Table::Schema DfsWeightBounded::CreateSchema() {
  return tables::DfsWeightBoundedTable::ComputeStaticSchema();
}
172
// Returns the name of this table function, taken from DfsWeightBoundedTable.
std::string DfsWeightBounded::TableName() {
  return tables::DfsWeightBoundedTable::Name();
}
176
// Returns a fixed placeholder estimate of the number of rows ComputeTable will
// produce; the value does not depend on the arguments.
uint32_t DfsWeightBounded::EstimateRowCount() {
  // TODO(lalitm): improve this estimate.
  return 1024;
}
181
ComputeTable(const std::vector<SqlValue> & arguments)182 base::StatusOr<std::unique_ptr<Table>> DfsWeightBounded::ComputeTable(
183 const std::vector<SqlValue>& arguments) {
184 PERFETTO_CHECK(arguments.size() == 6);
185
186 const SqlValue& raw_source_ids = arguments[0];
187 const SqlValue& raw_dest_ids = arguments[1];
188 const SqlValue& raw_edge_weights = arguments[2];
189 const SqlValue& raw_root_ids = arguments[3];
190 const SqlValue& raw_root_target_weights = arguments[4];
191 const SqlValue& raw_is_target_weight_floor = arguments[5];
192
193 if (raw_source_ids.is_null() && raw_dest_ids.is_null() &&
194 raw_edge_weights.is_null()) {
195 return std::unique_ptr<Table>(
196 std::make_unique<tables::DfsWeightBoundedTable>(pool_));
197 }
198
199 if (raw_root_ids.is_null() && raw_root_target_weights.is_null()) {
200 return std::unique_ptr<Table>(
201 std::make_unique<tables::DfsWeightBoundedTable>(pool_));
202 }
203
204 if (raw_source_ids.is_null() || raw_dest_ids.is_null() ||
205 raw_edge_weights.is_null() || raw_root_ids.is_null() ||
206 raw_root_target_weights.is_null()) {
207 return base::ErrStatus(
208 "dfs_weight_bounded: either all arguments should be null or none "
209 "should be");
210 }
211 if (raw_source_ids.type != SqlValue::kBytes) {
212 return base::ErrStatus(
213 "dfs_weight_bounded: source_node_ids should be a repeated field");
214 }
215 if (raw_dest_ids.type != SqlValue::kBytes) {
216 return base::ErrStatus(
217 "dfs_weight_bounded: dest_node_ids should be a repeated field");
218 }
219 if (raw_edge_weights.type != SqlValue::kBytes) {
220 return base::ErrStatus(
221 "dfs_weight_bounded: edge_weights should be a repeated field");
222 }
223 if (raw_root_ids.type != SqlValue::kBytes) {
224 return base::ErrStatus(
225 "dfs_weight_bounded: root_ids should be a repeated field");
226 }
227 if (raw_root_target_weights.type != SqlValue::kBytes) {
228 return base::ErrStatus(
229 "dfs_weight_bounded: root_target_weights should be a repeated field");
230 }
231
232 protos::pbzero::ProtoBuilderResult::Decoder proto_source_ids(
233 static_cast<const uint8_t*>(raw_source_ids.AsBytes()),
234 raw_source_ids.bytes_count);
235 if (!proto_source_ids.is_repeated()) {
236 return base::ErrStatus(
237 "dfs_weight_bounded: source_node_ids is not generated by RepeatedField "
238 "function");
239 }
240 protos::pbzero::RepeatedBuilderResult::Decoder source_ids(
241 proto_source_ids.repeated());
242
243 protos::pbzero::ProtoBuilderResult::Decoder proto_dest_ids(
244 static_cast<const uint8_t*>(raw_dest_ids.AsBytes()),
245 raw_dest_ids.bytes_count);
246 if (!proto_dest_ids.is_repeated()) {
247 return base::ErrStatus(
248 "dfs_weight_bounded: dest_node_ids is not generated by RepeatedField "
249 "function");
250 }
251 protos::pbzero::RepeatedBuilderResult::Decoder dest_ids(
252 proto_dest_ids.repeated());
253
254 protos::pbzero::ProtoBuilderResult::Decoder proto_edge_weights(
255 static_cast<const uint8_t*>(raw_edge_weights.AsBytes()),
256 raw_edge_weights.bytes_count);
257 if (!proto_edge_weights.is_repeated()) {
258 return base::ErrStatus(
259 "dfs_weight_bounded: edge_weights is not generated by RepeatedField "
260 "function");
261 }
262 protos::pbzero::RepeatedBuilderResult::Decoder edge_weights(
263 proto_edge_weights.repeated());
264
265 protos::pbzero::ProtoBuilderResult::Decoder proto_root_ids(
266 static_cast<const uint8_t*>(raw_root_ids.AsBytes()),
267 raw_root_ids.bytes_count);
268 if (!proto_root_ids.is_repeated()) {
269 return base::ErrStatus(
270 "dfs_weight_bounded: root_ids is not generated by RepeatedField "
271 "function");
272 }
273 protos::pbzero::RepeatedBuilderResult::Decoder root_ids(
274 proto_root_ids.repeated());
275
276 protos::pbzero::ProtoBuilderResult::Decoder proto_root_target_weights(
277 static_cast<const uint8_t*>(raw_root_target_weights.AsBytes()),
278 raw_root_target_weights.bytes_count);
279 if (!proto_root_target_weights.is_repeated()) {
280 return base::ErrStatus(
281 "dfs_weight_bounded: root_target_weights is not generated by "
282 "RepeatedField function");
283 }
284 protos::pbzero::RepeatedBuilderResult::Decoder root_target_weights(
285 proto_root_target_weights.repeated());
286
287 bool is_target_weight_floor =
288 static_cast<bool>(raw_is_target_weight_floor.AsLong());
289 ASSIGN_OR_RETURN(auto map, ParseSourceToDestionationsMap(source_ids, dest_ids,
290 edge_weights));
291
292 ASSIGN_OR_RETURN(auto roots,
293 ParseRootToMaxWeightMap(root_ids, root_target_weights));
294
295 auto table = std::make_unique<tables::DfsWeightBoundedTable>(pool_);
296 DfsWeightBoundedImpl(table.get(), map, roots, is_target_weight_floor);
297 return std::unique_ptr<Table>(std::move(table));
298 }
299
300 } // namespace perfetto::trace_processor
301