xref: /aosp_15_r20/external/perfetto/src/trace_processor/perfetto_sql/intrinsics/functions/graph_scan.cc (revision 6dbdd20afdafa5e3ca9b8809fa73465d530080dc)
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/trace_processor/perfetto_sql/intrinsics/functions/graph_scan.h"
18 
19 #include <algorithm>
20 #include <cinttypes>
21 #include <cstdint>
22 #include <memory>
23 #include <string>
24 #include <string_view>
25 #include <utility>
26 #include <variant>
27 #include <vector>
28 
29 #include "perfetto/base/logging.h"
30 #include "perfetto/base/status.h"
31 #include "perfetto/ext/base/status_or.h"
32 #include "perfetto/ext/base/string_utils.h"
33 #include "src/trace_processor/containers/string_pool.h"
34 #include "src/trace_processor/db/runtime_table.h"
35 #include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
36 #include "src/trace_processor/perfetto_sql/intrinsics/types/array.h"
37 #include "src/trace_processor/perfetto_sql/intrinsics/types/node.h"
38 #include "src/trace_processor/perfetto_sql/intrinsics/types/row_dataframe.h"
39 #include "src/trace_processor/perfetto_sql/intrinsics/types/value.h"
40 #include "src/trace_processor/perfetto_sql/parser/function_util.h"
41 #include "src/trace_processor/sqlite/bindings/sqlite_bind.h"
42 #include "src/trace_processor/sqlite/bindings/sqlite_column.h"
43 #include "src/trace_processor/sqlite/bindings/sqlite_function.h"
44 #include "src/trace_processor/sqlite/bindings/sqlite_result.h"
45 #include "src/trace_processor/sqlite/bindings/sqlite_stmt.h"
46 #include "src/trace_processor/sqlite/bindings/sqlite_type.h"
47 #include "src/trace_processor/sqlite/bindings/sqlite_value.h"
48 #include "src/trace_processor/sqlite/sql_source.h"
49 #include "src/trace_processor/sqlite/sqlite_engine.h"
50 #include "src/trace_processor/sqlite/sqlite_utils.h"
51 #include "src/trace_processor/util/status_macros.h"
52 
53 namespace perfetto::trace_processor {
54 namespace {
55 
InitToOutputAndStepTable(const perfetto_sql::RowDataframe & inits,const perfetto_sql::Graph & graph,RuntimeTable::Builder & step,uint32_t & step_row_count,RuntimeTable::Builder & out,uint32_t & out_row_count)56 base::Status InitToOutputAndStepTable(const perfetto_sql::RowDataframe& inits,
57                                       const perfetto_sql::Graph& graph,
58                                       RuntimeTable::Builder& step,
59                                       uint32_t& step_row_count,
60                                       RuntimeTable::Builder& out,
61                                       uint32_t& out_row_count) {
62   std::vector<uint32_t> empty_edges;
63   auto get_edges = [&](uint32_t id) {
64     return id < graph.size() ? graph[id].outgoing_edges : empty_edges;
65   };
66 
67   for (uint32_t i = 0; i < inits.size(); ++i) {
68     const auto* cell = inits.cells.data() + i * inits.column_names.size();
69     auto id = static_cast<uint32_t>(std::get<int64_t>(*cell));
70     RETURN_IF_ERROR(out.AddInteger(0, id));
71     out_row_count++;
72     for (uint32_t outgoing : get_edges(id)) {
73       step_row_count++;
74       RETURN_IF_ERROR(step.AddInteger(0, outgoing));
75     }
76     for (uint32_t j = 1; j < inits.column_names.size(); ++j) {
77       switch (cell[j].index()) {
78         case perfetto_sql::ValueIndex<std::monostate>():
79           RETURN_IF_ERROR(out.AddNull(j));
80           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
81             RETURN_IF_ERROR(step.AddNull(j));
82           }
83           break;
84         case perfetto_sql::ValueIndex<int64_t>(): {
85           int64_t r = std::get<int64_t>(cell[j]);
86           RETURN_IF_ERROR(out.AddInteger(j, r));
87           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
88             RETURN_IF_ERROR(step.AddInteger(j, r));
89           }
90           break;
91         }
92         case perfetto_sql::ValueIndex<double>(): {
93           double r = std::get<double>(cell[j]);
94           RETURN_IF_ERROR(out.AddFloat(j, r));
95           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
96             RETURN_IF_ERROR(step.AddFloat(j, r));
97           }
98           break;
99         }
100         case perfetto_sql::ValueIndex<std::string>(): {
101           const char* r = std::get<std::string>(cell[j]).c_str();
102           RETURN_IF_ERROR(out.AddText(j, r));
103           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
104             RETURN_IF_ERROR(step.AddText(j, r));
105           }
106           break;
107         }
108         default:
109           PERFETTO_FATAL("Invalid index");
110       }
111     }
112   }
113   return base::OkStatus();
114 }
115 
SqliteToOutputAndStepTable(SqliteEngine::PreparedStatement & stmt,const perfetto_sql::Graph & graph,RuntimeTable::Builder & step,uint32_t & step_row_count,RuntimeTable::Builder & out,uint32_t & out_row_count)116 base::Status SqliteToOutputAndStepTable(SqliteEngine::PreparedStatement& stmt,
117                                         const perfetto_sql::Graph& graph,
118                                         RuntimeTable::Builder& step,
119                                         uint32_t& step_row_count,
120                                         RuntimeTable::Builder& out,
121                                         uint32_t& out_row_count) {
122   std::vector<uint32_t> empty_edges;
123   auto get_edges = [&](uint32_t id) {
124     return id < graph.size() ? graph[id].outgoing_edges : empty_edges;
125   };
126 
127   uint32_t col_count = sqlite::column::Count(stmt.sqlite_stmt());
128   while (stmt.Step()) {
129     auto id =
130         static_cast<uint32_t>(sqlite::column::Int64(stmt.sqlite_stmt(), 0));
131     out_row_count++;
132     RETURN_IF_ERROR(out.AddInteger(0, id));
133     for (uint32_t outgoing : get_edges(id)) {
134       step_row_count++;
135       RETURN_IF_ERROR(step.AddInteger(0, outgoing));
136     }
137     for (uint32_t i = 1; i < col_count; ++i) {
138       switch (sqlite::column::Type(stmt.sqlite_stmt(), i)) {
139         case sqlite::Type::kNull:
140           RETURN_IF_ERROR(out.AddNull(i));
141           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
142             RETURN_IF_ERROR(step.AddNull(i));
143           }
144           break;
145         case sqlite::Type::kInteger: {
146           int64_t a = sqlite::column::Int64(stmt.sqlite_stmt(), i);
147           RETURN_IF_ERROR(out.AddInteger(i, a));
148           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
149             RETURN_IF_ERROR(step.AddInteger(i, a));
150           }
151           break;
152         }
153         case sqlite::Type::kText: {
154           const char* a = sqlite::column::Text(stmt.sqlite_stmt(), i);
155           RETURN_IF_ERROR(out.AddText(i, a));
156           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
157             RETURN_IF_ERROR(step.AddText(i, a));
158           }
159           break;
160         }
161         case sqlite::Type::kFloat: {
162           double a = sqlite::column::Double(stmt.sqlite_stmt(), i);
163           RETURN_IF_ERROR(out.AddFloat(i, a));
164           for ([[maybe_unused]] uint32_t _ : get_edges(id)) {
165             RETURN_IF_ERROR(step.AddFloat(i, a));
166           }
167           break;
168         }
169         case sqlite::Type::kBlob:
170           return base::ErrStatus("Unsupported blob type");
171       }
172     }
173   }
174   return stmt.status();
175 }
176 
PrepareStatement(PerfettoSqlEngine & engine,const std::vector<std::string> & cols,const std::string & sql)177 base::StatusOr<SqliteEngine::PreparedStatement> PrepareStatement(
178     PerfettoSqlEngine& engine,
179     const std::vector<std::string>& cols,
180     const std::string& sql) {
181   std::vector<std::string> select_cols;
182   std::vector<std::string> bind_cols;
183   for (uint32_t i = 0; i < cols.size(); ++i) {
184     select_cols.emplace_back(
185         base::StackString<1024>("c%" PRIu32 " as %s", i, cols[i].c_str())
186             .ToStdString());
187     bind_cols.emplace_back(base::StackString<1024>(
188                                "__intrinsic_table_ptr_bind(c%" PRIu32 ", '%s')",
189                                i, cols[i].c_str())
190                                .ToStdString());
191   }
192 
193   // TODO(lalitm): verify that the init aggregates line up correctly with the
194   // aggregation macro.
195   std::string raw_sql =
196       "(SELECT $cols FROM __intrinsic_table_ptr($var) WHERE $where)";
197   raw_sql = base::ReplaceAll(raw_sql, "$cols", base::Join(select_cols, ","));
198   raw_sql = base::ReplaceAll(raw_sql, "$where", base::Join(bind_cols, " AND "));
199   std::string res = base::ReplaceAll(sql, "$table", raw_sql);
200   return engine.PrepareSqliteStatement(
201       SqlSource::FromTraceProcessorImplementation("SELECT * FROM " + res));
202 }
203 
// Per-node bookkeeping used while walking the graph depth-first.
struct NodeState {
  // Progress of the depth-first traversal for a node.
  enum VisitState : uint8_t {
    // The traversal has not reached the node yet.
    kUnvisited,
    // The node's edges were pushed; waiting for them to be finished.
    kWaitingForDescendants,
    // The node's depth is final.
    kDone,
  };

  // Depth assigned to the node by the traversal; zero until computed.
  uint32_t depth = 0;
  VisitState visit_state = kUnvisited;
};
212 
213 struct DepthTable {
214   RuntimeTable::Builder builder;
215   uint32_t row_count = 0;
216 };
217 
218 struct GraphAggregatingScanner {
219   base::StatusOr<std::unique_ptr<RuntimeTable>> Run();
220   std::vector<uint32_t> InitializeStateFromMaxNode();
221   uint32_t DfsAndComputeMaxDepth(std::vector<uint32_t> stack);
222   base::Status PushDownStartingAggregates(RuntimeTable::Builder& res,
223                                           uint32_t& res_row_count);
224   base::Status PushDownAggregates(SqliteEngine::PreparedStatement& agg_stmt,
225                                   uint32_t agg_col_count,
226                                   RuntimeTable::Builder& res,
227                                   uint32_t& res_row_count);
228 
GetEdgesperfetto::trace_processor::__anonea0ce5610111::GraphAggregatingScanner229   const std::vector<uint32_t>& GetEdges(uint32_t id) {
230     return id < graph.size() ? graph[id].outgoing_edges : empty_edges;
231   }
232 
233   PerfettoSqlEngine* engine;
234   StringPool* pool;
235   const perfetto_sql::Graph& graph;
236   const perfetto_sql::RowDataframe& inits;
237   std::string_view reduce;
238   std::vector<uint32_t> empty_edges;
239 
240   std::vector<NodeState> state;
241   std::vector<DepthTable> tables_per_depth;
242 };
243 
InitializeStateFromMaxNode()244 std::vector<uint32_t> GraphAggregatingScanner::InitializeStateFromMaxNode() {
245   std::vector<uint32_t> stack;
246   auto nodes_size = static_cast<uint32_t>(graph.size());
247   for (uint32_t i = 0; i < inits.size(); ++i) {
248     auto start_id = static_cast<uint32_t>(
249         std::get<int64_t>(inits.cells[i * inits.column_names.size()]));
250     nodes_size = std::max(nodes_size, static_cast<uint32_t>(start_id) + 1);
251     for (uint32_t dest : GetEdges(start_id)) {
252       stack.emplace_back(static_cast<uint32_t>(dest));
253     }
254   }
255   state = std::vector<NodeState>(nodes_size);
256   return stack;
257 }
258 
DfsAndComputeMaxDepth(std::vector<uint32_t> stack)259 uint32_t GraphAggregatingScanner::DfsAndComputeMaxDepth(
260     std::vector<uint32_t> stack) {
261   uint32_t max_depth = 0;
262   while (!stack.empty()) {
263     uint32_t source_id = stack.back();
264     NodeState& source = state[source_id];
265     switch (source.visit_state) {
266       case NodeState::kUnvisited:
267         source.visit_state = NodeState::kWaitingForDescendants;
268         for (uint32_t dest_id : GetEdges(source_id)) {
269           stack.push_back(dest_id);
270         }
271         break;
272       case NodeState::kWaitingForDescendants:
273         stack.pop_back();
274         source.visit_state = NodeState::kDone;
275         for (uint32_t dest_id : GetEdges(source_id)) {
276           PERFETTO_DCHECK(state[dest_id].visit_state == NodeState::kDone);
277           source.depth = std::max(state[dest_id].depth + 1, source.depth);
278         }
279         max_depth = std::max(max_depth, source.depth);
280         break;
281       case NodeState::kDone:
282         stack.pop_back();
283         break;
284     }
285   }
286   return max_depth;
287 }
288 
PushDownAggregates(SqliteEngine::PreparedStatement & agg_stmt,uint32_t agg_col_count,RuntimeTable::Builder & res,uint32_t & res_row_count)289 base::Status GraphAggregatingScanner::PushDownAggregates(
290     SqliteEngine::PreparedStatement& agg_stmt,
291     uint32_t agg_col_count,
292     RuntimeTable::Builder& res,
293     uint32_t& res_row_count) {
294   while (agg_stmt.Step()) {
295     auto id =
296         static_cast<uint32_t>(sqlite::column::Int64(agg_stmt.sqlite_stmt(), 0));
297     res_row_count++;
298     RETURN_IF_ERROR(res.AddInteger(0, id));
299     for (uint32_t outgoing : GetEdges(id)) {
300       auto& dt = tables_per_depth[state[outgoing].depth];
301       dt.row_count++;
302       RETURN_IF_ERROR(dt.builder.AddInteger(0, outgoing));
303     }
304     for (uint32_t i = 1; i < agg_col_count; ++i) {
305       switch (sqlite::column::Type(agg_stmt.sqlite_stmt(), i)) {
306         case sqlite::Type::kNull:
307           RETURN_IF_ERROR(res.AddNull(i));
308           for (uint32_t outgoing : GetEdges(id)) {
309             auto& dt = tables_per_depth[state[outgoing].depth];
310             RETURN_IF_ERROR(dt.builder.AddNull(i));
311           }
312           break;
313         case sqlite::Type::kInteger: {
314           int64_t a = sqlite::column::Int64(agg_stmt.sqlite_stmt(), i);
315           RETURN_IF_ERROR(res.AddInteger(i, a));
316           for (uint32_t outgoing : GetEdges(id)) {
317             auto& dt = tables_per_depth[state[outgoing].depth];
318             RETURN_IF_ERROR(dt.builder.AddInteger(i, a));
319           }
320           break;
321         }
322         case sqlite::Type::kText: {
323           const char* a = sqlite::column::Text(agg_stmt.sqlite_stmt(), i);
324           RETURN_IF_ERROR(res.AddText(i, a));
325           for (uint32_t outgoing : GetEdges(id)) {
326             auto& dt = tables_per_depth[state[outgoing].depth];
327             RETURN_IF_ERROR(dt.builder.AddText(i, a));
328           }
329           break;
330         }
331         case sqlite::Type::kFloat: {
332           double a = sqlite::column::Double(agg_stmt.sqlite_stmt(), i);
333           RETURN_IF_ERROR(res.AddFloat(i, a));
334           for (uint32_t outgoing : GetEdges(id)) {
335             auto& dt = tables_per_depth[state[outgoing].depth];
336             RETURN_IF_ERROR(dt.builder.AddFloat(i, a));
337           }
338           break;
339         }
340         case sqlite::Type::kBlob:
341           return base::ErrStatus("Unsupported blob type");
342       }
343     }
344   }
345   return agg_stmt.status();
346 }
347 
PushDownStartingAggregates(RuntimeTable::Builder & res,uint32_t & res_row_count)348 base::Status GraphAggregatingScanner::PushDownStartingAggregates(
349     RuntimeTable::Builder& res,
350     uint32_t& res_row_count) {
351   for (uint32_t i = 0; i < inits.size(); ++i) {
352     const auto* cell = inits.cells.data() + i * inits.column_names.size();
353     auto id = static_cast<uint32_t>(std::get<int64_t>(*cell));
354     RETURN_IF_ERROR(res.AddInteger(0, id));
355     res_row_count++;
356     for (uint32_t outgoing : GetEdges(id)) {
357       auto& dt = tables_per_depth[state[outgoing].depth];
358       dt.row_count++;
359       RETURN_IF_ERROR(dt.builder.AddInteger(0, outgoing));
360     }
361     for (uint32_t j = 1; j < inits.column_names.size(); ++j) {
362       switch (cell[j].index()) {
363         case perfetto_sql::ValueIndex<std::monostate>():
364           RETURN_IF_ERROR(res.AddNull(j));
365           for (uint32_t outgoing : GetEdges(id)) {
366             auto& dt = tables_per_depth[state[outgoing].depth];
367             RETURN_IF_ERROR(dt.builder.AddNull(j));
368           }
369           break;
370         case perfetto_sql::ValueIndex<int64_t>(): {
371           int64_t r = std::get<int64_t>(cell[j]);
372           RETURN_IF_ERROR(res.AddInteger(j, r));
373           for (uint32_t outgoing : GetEdges(id)) {
374             auto& dt = tables_per_depth[state[outgoing].depth];
375             RETURN_IF_ERROR(dt.builder.AddInteger(j, r));
376           }
377           break;
378         }
379         case perfetto_sql::ValueIndex<double>(): {
380           double r = std::get<double>(cell[j]);
381           RETURN_IF_ERROR(res.AddFloat(j, r));
382           for (uint32_t outgoing : GetEdges(id)) {
383             auto& dt = tables_per_depth[state[outgoing].depth];
384             RETURN_IF_ERROR(dt.builder.AddFloat(j, r));
385           }
386           break;
387         }
388         case perfetto_sql::ValueIndex<std::string>(): {
389           const char* r = std::get<std::string>(cell[j]).c_str();
390           RETURN_IF_ERROR(res.AddText(j, r));
391           for (uint32_t outgoing : GetEdges(id)) {
392             auto& dt = tables_per_depth[state[outgoing].depth];
393             RETURN_IF_ERROR(dt.builder.AddText(j, r));
394           }
395           break;
396         }
397         default:
398           PERFETTO_FATAL("Invalid index");
399       }
400     }
401   }
402   return base::OkStatus();
403 }
404 
Run()405 base::StatusOr<std::unique_ptr<RuntimeTable>> GraphAggregatingScanner::Run() {
406   if (!inits.id_column_index) {
407     return base::ErrStatus(
408         "graph_aggregating_scan: 'id' column is not present in initial nodes "
409         "table");
410   }
411   if (inits.id_column_index != 0) {
412     return base::ErrStatus(
413         "graph_aggregating_scan: 'id' column must be the first column in the "
414         "initial nodes table");
415   }
416 
417   // The basic idea of this algorithm is as follows:
418   // 1) Setup the state vector by figuring out the maximum id in the initial and
419   //    graph tables.
420   // 2) Do a DFS to compute the depth of each node and figure out the max depth.
421   // 3) Setup all the table builders for each depth.
422   // 4) For all the starting nodes, push down their values to their dependents
423   //    and also store the aggregates in the final result table.
424   // 5) Going from highest depth downward, run the aggregation SQL the user
425   //    specified, push down those values to their dependents and also store the
426   //    aggregates in the final result table.
427   // 6) Return the final result table.
428   //
429   // The complexity of this algorithm is O(n) in both memory and CPU.
430   //
431   // TODO(lalitm): there is a significant optimization we can do here: instead
432   // of pulling the data from SQL to C++ and then feeding that to the runtime
433   // table builder, we could just have an aggregate function which directly
434   // writes into the table itself. This would be better because:
435   //   1) It would be faster
436   //   2) It would remove the need for first creating a row dataframe and then a
437   //      table builder for the initial nodes
438   //   3) It would allow code deduplication between the initial query, the step
439   //      query and also CREATE PERFETTO TABLE: the code here is very similar to
440   //      the code in PerfettoSqlEngine.
441 
442   RuntimeTable::Builder res(pool, inits.column_names);
443   uint32_t res_row_count = 0;
444   uint32_t max_depth = DfsAndComputeMaxDepth(InitializeStateFromMaxNode());
445 
446   for (uint32_t i = 0; i < max_depth + 1; ++i) {
447     tables_per_depth.emplace_back(
448         DepthTable{RuntimeTable::Builder(pool, inits.column_names), 0});
449   }
450 
451   RETURN_IF_ERROR(PushDownStartingAggregates(res, res_row_count));
452   ASSIGN_OR_RETURN(auto agg_stmt, PrepareStatement(*engine, inits.column_names,
453                                                    std::string(reduce)));
454   RETURN_IF_ERROR(agg_stmt.status());
455 
456   uint32_t agg_col_count = sqlite::column::Count(agg_stmt.sqlite_stmt());
457   std::vector<std::string> aggregate_cols;
458   aggregate_cols.reserve(agg_col_count);
459   for (uint32_t i = 0; i < agg_col_count; ++i) {
460     aggregate_cols.emplace_back(
461         sqlite::column::Name(agg_stmt.sqlite_stmt(), i));
462   }
463 
464   if (aggregate_cols != inits.column_names) {
465     return base::ErrStatus(
466         "graph_scan: aggregate SQL columns do not match init columns");
467   }
468 
469   for (auto i = static_cast<int64_t>(tables_per_depth.size() - 1); i >= 0;
470        --i) {
471     int err = sqlite::stmt::Reset(agg_stmt.sqlite_stmt());
472     if (err != SQLITE_OK) {
473       return base::ErrStatus("Failed to reset statement");
474     }
475     auto idx = static_cast<uint32_t>(i);
476     ASSIGN_OR_RETURN(auto depth_tab,
477                      std::move(tables_per_depth[idx].builder)
478                          .Build(tables_per_depth[idx].row_count));
479     err = sqlite::bind::Pointer(
480         agg_stmt.sqlite_stmt(), 1, depth_tab.release(), "TABLE", [](void* tab) {
481           std::unique_ptr<RuntimeTable>(static_cast<RuntimeTable*>(tab));
482         });
483     if (err != SQLITE_OK) {
484       return base::ErrStatus("Failed to bind pointer %d", err);
485     }
486     RETURN_IF_ERROR(
487         PushDownAggregates(agg_stmt, agg_col_count, res, res_row_count));
488   }
489   return std::move(res).Build(res_row_count);
490 }
491 
492 struct GraphAggregatingScan : public SqliteFunction<GraphAggregatingScan> {
493   static constexpr char kName[] = "__intrinsic_graph_aggregating_scan";
494   static constexpr int kArgCount = 4;
495   struct UserDataContext {
496     PerfettoSqlEngine* engine;
497     StringPool* pool;
498   };
499 
Stepperfetto::trace_processor::__anonea0ce5610111::GraphAggregatingScan500   static void Step(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
501     PERFETTO_DCHECK(argc == kArgCount);
502 
503     auto* user_data = GetUserData(ctx);
504     const char* reduce = sqlite::value::Text(argv[2]);
505     if (!reduce) {
506       return sqlite::result::Error(
507           ctx, "graph_aggregating_scan: aggegate SQL cannot be null");
508     }
509     const char* column_list = sqlite::value::Text(argv[3]);
510     if (!column_list) {
511       return sqlite::result::Error(
512           ctx, "graph_aggregating_scan: column list cannot be null");
513     }
514 
515     std::vector<std::string> col_names{"id"};
516     for (const auto& c :
517          base::SplitString(base::StripChars(column_list, "()", ' '), ",")) {
518       col_names.push_back(base::TrimWhitespace(c));
519     }
520 
521     const auto* init = sqlite::value::Pointer<perfetto_sql::RowDataframe>(
522         argv[1], "ROW_DATAFRAME");
523     if (!init) {
524       SQLITE_ASSIGN_OR_RETURN(
525           ctx, auto table,
526           RuntimeTable::Builder(user_data->pool, col_names).Build(0));
527       return sqlite::result::UniquePointer(ctx, std::move(table), "TABLE");
528     }
529     if (col_names != init->column_names) {
530       return sqlite::result::Error(
531           ctx, base::StackString<1024>(
532                    "graph_aggregating_scan: column list '%s' does not match "
533                    "initial table list '%s'",
534                    base::Join(col_names, ",").c_str(),
535                    base::Join(init->column_names, ",").c_str())
536                    .c_str());
537     }
538 
539     const auto* nodes =
540         sqlite::value::Pointer<perfetto_sql::Graph>(argv[0], "GRAPH");
541     GraphAggregatingScanner scanner{
542         user_data->engine,
543         user_data->pool,
544         nodes ? *nodes : perfetto_sql::Graph(),
545         *init,
546         reduce,
547         {},
548         {},
549         {},
550     };
551     auto result = scanner.Run();
552     if (!result.ok()) {
553       return sqlite::utils::SetError(ctx, result.status());
554     }
555     return sqlite::result::UniquePointer(ctx, std::move(*result), "TABLE");
556   }
557 };
558 
559 struct GraphScan : public SqliteFunction<GraphScan> {
560   static constexpr char kName[] = "__intrinsic_graph_scan";
561   static constexpr int kArgCount = 4;
562   struct UserDataContext {
563     PerfettoSqlEngine* engine;
564     StringPool* pool;
565   };
566 
Stepperfetto::trace_processor::__anonea0ce5610111::GraphScan567   static void Step(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
568     PERFETTO_DCHECK(argc == kArgCount);
569 
570     auto* user_data = GetUserData(ctx);
571     const char* step_sql = sqlite::value::Text(argv[2]);
572     if (!step_sql) {
573       return sqlite::result::Error(ctx, "graph_scan: step SQL cannot be null");
574     }
575     const char* column_list = sqlite::value::Text(argv[3]);
576     if (!column_list) {
577       return sqlite::result::Error(ctx,
578                                    "graph_scan: column list cannot be null");
579     }
580 
581     std::vector<std::string> col_names{"id"};
582     for (const auto& c :
583          base::SplitString(base::StripChars(column_list, "()", ' '), ",")) {
584       col_names.push_back(base::TrimWhitespace(c));
585     }
586 
587     const auto* init = sqlite::value::Pointer<perfetto_sql::RowDataframe>(
588         argv[1], "ROW_DATAFRAME");
589     if (!init) {
590       SQLITE_ASSIGN_OR_RETURN(
591           ctx, auto table,
592           RuntimeTable::Builder(user_data->pool, col_names).Build(0));
593       return sqlite::result::UniquePointer(ctx, std::move(table), "TABLE");
594     }
595     if (col_names != init->column_names) {
596       base::StackString<1024> errmsg(
597           "graph_scan: column list '%s' does not match initial table list '%s'",
598           base::Join(col_names, ",").c_str(),
599           base::Join(init->column_names, ",").c_str());
600       return sqlite::result::Error(ctx, errmsg.c_str());
601     }
602 
603     const auto* raw_graph =
604         sqlite::value::Pointer<perfetto_sql::Graph>(argv[0], "GRAPH");
605     const auto& graph = raw_graph ? *raw_graph : perfetto_sql::Graph();
606 
607     RuntimeTable::Builder out(user_data->pool, init->column_names);
608     uint32_t out_count = 0;
609 
610     std::unique_ptr<RuntimeTable> step_table;
611     {
612       RuntimeTable::Builder step(user_data->pool, init->column_names);
613       uint32_t step_count = 0;
614       SQLITE_RETURN_IF_ERROR(
615           ctx, InitToOutputAndStepTable(*init, graph, step, step_count, out,
616                                         out_count));
617       SQLITE_ASSIGN_OR_RETURN(ctx, step_table,
618                               std::move(step).Build(step_count));
619     }
620     SQLITE_ASSIGN_OR_RETURN(
621         ctx, auto agg_stmt,
622         PrepareStatement(*user_data->engine, init->column_names, step_sql));
623     while (step_table->row_count() > 0) {
624       int err = sqlite::stmt::Reset(agg_stmt.sqlite_stmt());
625       if (err != SQLITE_OK) {
626         return sqlite::utils::SetError(ctx, "Failed to reset statement");
627       }
628       err = sqlite::bind::UniquePointer(agg_stmt.sqlite_stmt(), 1,
629                                         std::move(step_table), "TABLE");
630       if (err != SQLITE_OK) {
631         return sqlite::utils::SetError(
632             ctx,
633             base::StackString<1024>("Failed to bind pointer %d", err).c_str());
634       }
635 
636       RuntimeTable::Builder step(user_data->pool, init->column_names);
637       uint32_t step_count = 0;
638       SQLITE_RETURN_IF_ERROR(
639           ctx, SqliteToOutputAndStepTable(agg_stmt, graph, step, step_count,
640                                           out, out_count));
641       SQLITE_ASSIGN_OR_RETURN(ctx, step_table,
642                               std::move(step).Build(step_count));
643     }
644     SQLITE_ASSIGN_OR_RETURN(ctx, auto res, std::move(out).Build(out_count));
645     return sqlite::result::UniquePointer(ctx, std::move(res), "TABLE");
646   }
647 };
648 
649 }  // namespace
650 
RegisterGraphScanFunctions(PerfettoSqlEngine & engine,StringPool * pool)651 base::Status RegisterGraphScanFunctions(PerfettoSqlEngine& engine,
652                                         StringPool* pool) {
653   RETURN_IF_ERROR(engine.RegisterSqliteFunction<GraphScan>(
654       std::make_unique<GraphScan::UserDataContext>(
655           GraphScan::UserDataContext{&engine, pool})));
656   return engine.RegisterSqliteFunction<GraphAggregatingScan>(
657       std::make_unique<GraphAggregatingScan::UserDataContext>(
658           GraphAggregatingScan::UserDataContext{&engine, pool}));
659 }
660 
661 }  // namespace perfetto::trace_processor
662