1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/trace_processor/perfetto_sql/intrinsics/table_functions/dfs_weight_bounded.h"
18 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
27 
28 #include "perfetto/base/logging.h"
29 #include "perfetto/base/status.h"
30 #include "perfetto/ext/base/status_or.h"
31 #include "perfetto/protozero/proto_decoder.h"
32 #include "perfetto/trace_processor/basic_types.h"
33 #include "protos/perfetto/trace_processor/metrics_impl.pbzero.h"
34 #include "src/trace_processor/containers/string_pool.h"
35 #include "src/trace_processor/db/column.h"
36 #include "src/trace_processor/db/table.h"
37 #include "src/trace_processor/perfetto_sql/intrinsics/table_functions/tables_py.h"
38 #include "src/trace_processor/util/status_macros.h"
39 
40 namespace perfetto::trace_processor {
41 namespace tables {
42 DfsWeightBoundedTable::~DfsWeightBoundedTable() = default;
43 }  // namespace tables
44 
45 namespace {
46 struct Edge {
47   uint32_t id;
48   uint32_t weight;
49 };
50 using Destinations = std::vector<Edge>;
51 
ParseSourceToDestionationsMap(protos::pbzero::RepeatedBuilderResult::Decoder & source,protos::pbzero::RepeatedBuilderResult::Decoder & dest,protos::pbzero::RepeatedBuilderResult::Decoder & weight)52 base::StatusOr<std::vector<Destinations>> ParseSourceToDestionationsMap(
53     protos::pbzero::RepeatedBuilderResult::Decoder& source,
54     protos::pbzero::RepeatedBuilderResult::Decoder& dest,
55     protos::pbzero::RepeatedBuilderResult::Decoder& weight) {
56   std::vector<Destinations> source_to_destinations_map;
57   bool parse_error = false;
58   auto source_node_ids = source.int_values(&parse_error);
59   auto dest_node_ids = dest.int_values(&parse_error);
60   auto edge_weights = weight.int_values(&parse_error);
61 
62   for (; source_node_ids && dest_node_ids && edge_weights;
63        ++source_node_ids, ++dest_node_ids, ++edge_weights) {
64     source_to_destinations_map.resize(
65         std::max(source_to_destinations_map.size(),
66                  std::max(static_cast<size_t>(*source_node_ids + 1),
67                           static_cast<size_t>(*dest_node_ids + 1))));
68     source_to_destinations_map[static_cast<uint32_t>(*source_node_ids)]
69         .push_back(Edge{static_cast<uint32_t>(*dest_node_ids),
70                         static_cast<uint32_t>(*edge_weights)});
71   }
72   if (parse_error) {
73     return base::ErrStatus("Failed while parsing source or dest ids");
74   }
75   if (static_cast<bool>(source_node_ids) != static_cast<bool>(dest_node_ids)) {
76     return base::ErrStatus(
77         "dfs_weight_bounded: length of source and destination columns is not "
78         "the same");
79   }
80   return source_to_destinations_map;
81 }
82 
ParseRootToMaxWeightMap(protos::pbzero::RepeatedBuilderResult::Decoder & start,protos::pbzero::RepeatedBuilderResult::Decoder & end)83 base::StatusOr<std::vector<Edge>> ParseRootToMaxWeightMap(
84     protos::pbzero::RepeatedBuilderResult::Decoder& start,
85     protos::pbzero::RepeatedBuilderResult::Decoder& end) {
86   std::vector<Edge> roots;
87   bool parse_error = false;
88   auto root_node_ids = start.int_values(&parse_error);
89   auto target_weights = end.int_values(&parse_error);
90 
91   for (; root_node_ids && target_weights; ++root_node_ids, ++target_weights) {
92     roots.push_back(Edge{static_cast<uint32_t>(*root_node_ids),
93                          static_cast<uint32_t>(*target_weights)});
94   }
95 
96   if (parse_error) {
97     return base::ErrStatus(
98         "Failed while parsing root_node_ids or root_target_weights");
99   }
100   if (static_cast<bool>(root_node_ids) != static_cast<bool>(target_weights)) {
101     return base::ErrStatus(
102         "dfs_weight_bounded: length of root_node_ids and root_target_weights "
103         "columns is not the same");
104   }
105   return roots;
106 }
107 
DfsWeightBoundedImpl(tables::DfsWeightBoundedTable * table,const std::vector<Destinations> & source_to_destinations_map,const std::vector<Edge> & roots,const bool is_target_weight_floor)108 void DfsWeightBoundedImpl(
109     tables::DfsWeightBoundedTable* table,
110     const std::vector<Destinations>& source_to_destinations_map,
111     const std::vector<Edge>& roots,
112     const bool is_target_weight_floor) {
113   struct StackState {
114     uint32_t id;
115     uint32_t weight;
116     std::optional<uint32_t> parent_id;
117   };
118 
119   std::vector<uint8_t> seen_node_ids(source_to_destinations_map.size());
120   std::vector<StackState> stack;
121 
122   for (const auto& root : roots) {
123     stack.clear();
124     stack.push_back({root.id, 0, std::nullopt});
125     std::fill(seen_node_ids.begin(), seen_node_ids.end(), 0);
126 
127     for (uint32_t total_weight = 0; !stack.empty();) {
128       StackState stack_state = stack.back();
129       stack.pop_back();
130 
131       if (seen_node_ids[stack_state.id]) {
132         continue;
133       }
134       seen_node_ids[stack_state.id] = true;
135       total_weight += stack_state.weight;
136 
137       if (!is_target_weight_floor && total_weight > root.weight) {
138         // If target weight is a ceiling weight then we don't want to include
139         // the last node that crosses the threshold.
140         break;
141       }
142 
143       tables::DfsWeightBoundedTable::Row row;
144       row.root_node_id = root.id;
145       row.node_id = stack_state.id;
146       row.parent_node_id = stack_state.parent_id;
147       table->Insert(row);
148 
149       if (total_weight > root.weight) {
150         // If the target weight is a floor weight, we add the last node that
151         // crossed the threshold before exiting the search.
152         break;
153       }
154 
155       PERFETTO_DCHECK(stack_state.id < source_to_destinations_map.size());
156 
157       const auto& children = source_to_destinations_map[stack_state.id];
158       for (auto it = children.rbegin(); it != children.rend(); ++it) {
159         stack.emplace_back(StackState{(*it).id, (*it).weight, stack_state.id});
160       }
161     }
162   }
163 }
164 }  // namespace
165 
DfsWeightBounded(StringPool * pool)166 DfsWeightBounded::DfsWeightBounded(StringPool* pool) : pool_(pool) {}
167 DfsWeightBounded::~DfsWeightBounded() = default;
168 
CreateSchema()169 Table::Schema DfsWeightBounded::CreateSchema() {
170   return tables::DfsWeightBoundedTable::ComputeStaticSchema();
171 }
172 
TableName()173 std::string DfsWeightBounded::TableName() {
174   return tables::DfsWeightBoundedTable::Name();
175 }
176 
EstimateRowCount()177 uint32_t DfsWeightBounded::EstimateRowCount() {
178   // TODO(lalitm): improve this estimate.
179   return 1024;
180 }
181 
ComputeTable(const std::vector<SqlValue> & arguments)182 base::StatusOr<std::unique_ptr<Table>> DfsWeightBounded::ComputeTable(
183     const std::vector<SqlValue>& arguments) {
184   PERFETTO_CHECK(arguments.size() == 6);
185 
186   const SqlValue& raw_source_ids = arguments[0];
187   const SqlValue& raw_dest_ids = arguments[1];
188   const SqlValue& raw_edge_weights = arguments[2];
189   const SqlValue& raw_root_ids = arguments[3];
190   const SqlValue& raw_root_target_weights = arguments[4];
191   const SqlValue& raw_is_target_weight_floor = arguments[5];
192 
193   if (raw_source_ids.is_null() && raw_dest_ids.is_null() &&
194       raw_edge_weights.is_null()) {
195     return std::unique_ptr<Table>(
196         std::make_unique<tables::DfsWeightBoundedTable>(pool_));
197   }
198 
199   if (raw_root_ids.is_null() && raw_root_target_weights.is_null()) {
200     return std::unique_ptr<Table>(
201         std::make_unique<tables::DfsWeightBoundedTable>(pool_));
202   }
203 
204   if (raw_source_ids.is_null() || raw_dest_ids.is_null() ||
205       raw_edge_weights.is_null() || raw_root_ids.is_null() ||
206       raw_root_target_weights.is_null()) {
207     return base::ErrStatus(
208         "dfs_weight_bounded: either all arguments should be null or none "
209         "should be");
210   }
211   if (raw_source_ids.type != SqlValue::kBytes) {
212     return base::ErrStatus(
213         "dfs_weight_bounded: source_node_ids should be a repeated field");
214   }
215   if (raw_dest_ids.type != SqlValue::kBytes) {
216     return base::ErrStatus(
217         "dfs_weight_bounded: dest_node_ids should be a repeated field");
218   }
219   if (raw_edge_weights.type != SqlValue::kBytes) {
220     return base::ErrStatus(
221         "dfs_weight_bounded: edge_weights should be a repeated field");
222   }
223   if (raw_root_ids.type != SqlValue::kBytes) {
224     return base::ErrStatus(
225         "dfs_weight_bounded: root_ids should be a repeated field");
226   }
227   if (raw_root_target_weights.type != SqlValue::kBytes) {
228     return base::ErrStatus(
229         "dfs_weight_bounded: root_target_weights should be a repeated field");
230   }
231 
232   protos::pbzero::ProtoBuilderResult::Decoder proto_source_ids(
233       static_cast<const uint8_t*>(raw_source_ids.AsBytes()),
234       raw_source_ids.bytes_count);
235   if (!proto_source_ids.is_repeated()) {
236     return base::ErrStatus(
237         "dfs_weight_bounded: source_node_ids is not generated by RepeatedField "
238         "function");
239   }
240   protos::pbzero::RepeatedBuilderResult::Decoder source_ids(
241       proto_source_ids.repeated());
242 
243   protos::pbzero::ProtoBuilderResult::Decoder proto_dest_ids(
244       static_cast<const uint8_t*>(raw_dest_ids.AsBytes()),
245       raw_dest_ids.bytes_count);
246   if (!proto_dest_ids.is_repeated()) {
247     return base::ErrStatus(
248         "dfs_weight_bounded: dest_node_ids is not generated by RepeatedField "
249         "function");
250   }
251   protos::pbzero::RepeatedBuilderResult::Decoder dest_ids(
252       proto_dest_ids.repeated());
253 
254   protos::pbzero::ProtoBuilderResult::Decoder proto_edge_weights(
255       static_cast<const uint8_t*>(raw_edge_weights.AsBytes()),
256       raw_edge_weights.bytes_count);
257   if (!proto_edge_weights.is_repeated()) {
258     return base::ErrStatus(
259         "dfs_weight_bounded: edge_weights is not generated by RepeatedField "
260         "function");
261   }
262   protos::pbzero::RepeatedBuilderResult::Decoder edge_weights(
263       proto_edge_weights.repeated());
264 
265   protos::pbzero::ProtoBuilderResult::Decoder proto_root_ids(
266       static_cast<const uint8_t*>(raw_root_ids.AsBytes()),
267       raw_root_ids.bytes_count);
268   if (!proto_root_ids.is_repeated()) {
269     return base::ErrStatus(
270         "dfs_weight_bounded: root_ids is not generated by RepeatedField "
271         "function");
272   }
273   protos::pbzero::RepeatedBuilderResult::Decoder root_ids(
274       proto_root_ids.repeated());
275 
276   protos::pbzero::ProtoBuilderResult::Decoder proto_root_target_weights(
277       static_cast<const uint8_t*>(raw_root_target_weights.AsBytes()),
278       raw_root_target_weights.bytes_count);
279   if (!proto_root_target_weights.is_repeated()) {
280     return base::ErrStatus(
281         "dfs_weight_bounded: root_target_weights is not generated by "
282         "RepeatedField function");
283   }
284   protos::pbzero::RepeatedBuilderResult::Decoder root_target_weights(
285       proto_root_target_weights.repeated());
286 
287   bool is_target_weight_floor =
288       static_cast<bool>(raw_is_target_weight_floor.AsLong());
289   ASSIGN_OR_RETURN(auto map, ParseSourceToDestionationsMap(source_ids, dest_ids,
290                                                            edge_weights));
291 
292   ASSIGN_OR_RETURN(auto roots,
293                    ParseRootToMaxWeightMap(root_ids, root_target_weights));
294 
295   auto table = std::make_unique<tables::DfsWeightBoundedTable>(pool_);
296   DfsWeightBoundedImpl(table.get(), map, roots, is_target_weight_floor);
297   return std::unique_ptr<Table>(std::move(table));
298 }
299 
300 }  // namespace perfetto::trace_processor
301