xref: /aosp_15_r20/external/perfetto/python/test/api_integrationtest.py (revision 6dbdd20afdafa5e3ca9b8809fa73465d530080dc)
1#!/usr/bin/env python3
2# Copyright (C) 2020 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import io
17import os
18import tempfile
19import unittest
20from typing import Optional
21
22import pandas as pd
23
24from perfetto.batch_trace_processor.api import BatchTraceProcessor
25from perfetto.batch_trace_processor.api import BatchTraceProcessorConfig
26from perfetto.batch_trace_processor.api import FailureHandling
27from perfetto.batch_trace_processor.api import Metadata
28from perfetto.batch_trace_processor.api import TraceListReference
29from perfetto.trace_processor.api import PLATFORM_DELEGATE
30from perfetto.trace_processor.api import TraceProcessor
31from perfetto.trace_processor.api import TraceProcessorException
32from perfetto.trace_processor.api import TraceProcessorConfig
33from perfetto.trace_processor.api import TraceReference
34from perfetto.trace_uri_resolver.resolver import TraceUriResolver
35from perfetto.trace_uri_resolver.path import PathUriResolver
36
37
class SimpleResolver(TraceUriResolver):
  """Test resolver returning the same trace in several reference forms.

  Every result carries a 'source' metadata entry naming its form: 'generator'
  and 'path' always, plus 'path_resolver' and 'file' unless
  skip_resolve_file is set.
  """
  PREFIX = 'simple'

  def __init__(self, path, skip_resolve_file=False):
    self.path = path
    # When built from a URI (e.g. 'simple:path=...;skip_resolve_file=false')
    # the flag may arrive as a string, and any non-empty string is truthy;
    # normalise it so that 'false' really means False.
    self.skip_resolve_file = self._to_bool(skip_resolve_file)

  @staticmethod
  def _to_bool(value):
    """Returns `value` as a bool, parsing strings such as 'true'/'False'."""
    if isinstance(value, str):
      return value.strip().lower() in ('true', '1', 'yes')
    return bool(value)

  def file_gen(self):
    """Yields the whole trace at self.path as a single bytes chunk."""
    with open(self.path, 'rb') as f:
      yield f.read()

  def resolve(self):
    """Returns the trace as a generator, a path and (optionally) more forms."""
    res = [
        TraceUriResolver.Result(
            self.file_gen(), metadata={'source': 'generator'}),
        TraceUriResolver.Result(self.path, metadata={'source': 'path'}),
    ]
    if not self.skip_resolve_file:
      # Open the file only when it is actually handed out. Opening it in
      # __init__ leaked a handle whenever skip_resolve_file was set, and a
      # single shared handle would be exhausted on a second resolve().
      # Nothing in this class closes it; the consumer reads from it.
      res.extend([
          TraceUriResolver.Result(
              PathUriResolver(self.path),
              metadata={'source': 'path_resolver'}),
          TraceUriResolver.Result(
              open(self.path, 'rb'), metadata={'source': 'file'}),
      ])
    return res
65
66
class RecursiveResolver(SimpleResolver):
  """Test resolver whose results include further resolvers.

  Besides a raw bytes generator, it returns a 'simple:' URI string and a
  SimpleResolver instance. Each of these two is tagged with matching
  'source' and 'root_source' metadata, so tests can exercise nested
  resolution.
  """
  PREFIX = 'recursive'

  def __init__(self, path, skip_resolve_file):
    # Unlike SimpleResolver, skip_resolve_file has no default here.
    super().__init__(path=path, skip_resolve_file=skip_resolve_file)

  def resolve(self):
    """Returns a generator result plus two nested SimpleResolver results."""

    def tagged(label):
      return {'source': label, 'root_source': label}

    nested_uri = 'simple:path={};skip_resolve_file={}'.format(
        self.path, self.skip_resolve_file)
    nested_resolver = SimpleResolver(
        path=self.path, skip_resolve_file=self.skip_resolve_file)

    gen_result = TraceUriResolver.Result(
        self.file_gen(), metadata={'source': 'recursive_gen'})
    uri_result = TraceUriResolver.Result(
        nested_uri, metadata=tagged('recursive_path'))
    obj_result = TraceUriResolver.Result(
        nested_resolver, metadata=tagged('recursive_obj'))
    return [gen_result, uri_result, obj_result]
92
93
class SimpleObserver(BatchTraceProcessor.Observer):
  """Observer that records each reported execution time, in call order."""

  def __init__(self):
    # Seconds reported per processed trace, appended by trace_processed().
    self.execution_times = list()

  def trace_processed(self, metadata: Metadata, execution_time_seconds: float):
    """Stores the reported duration; `metadata` is not used."""
    times = self.execution_times
    times.append(execution_time_seconds)
101
102
def create_batch_tp(
    traces: TraceListReference,
    load_failure_handling: FailureHandling = FailureHandling.RAISE_EXCEPTION,
    execute_failure_handling: FailureHandling = FailureHandling.RAISE_EXCEPTION,
    observer: Optional[BatchTraceProcessor.Observer] = None):
  """Builds a BatchTraceProcessor over `traces` for these tests.

  The resolver registry is the platform default plus the test-only
  SimpleResolver and RecursiveResolver, so 'simple:' and 'recursive:' URIs
  can be used as trace references. The shell binary comes from $SHELL_PATH.
  """
  resolvers = PLATFORM_DELEGATE().default_resolver_registry()
  for resolver_cls in (SimpleResolver, RecursiveResolver):
    resolvers.register(resolver_cls)

  tp_config = TraceProcessorConfig(
      bin_path=os.environ["SHELL_PATH"], resolver_registry=resolvers)
  batch_config = BatchTraceProcessorConfig(
      load_failure_handling=load_failure_handling,
      execute_failure_handling=execute_failure_handling,
      tp_config=tp_config)
  return BatchTraceProcessor(
      traces=traces, config=batch_config, observer=observer)
117
118
def create_tp(trace: TraceReference):
  """Starts a TraceProcessor on `trace` using the $SHELL_PATH binary."""
  tp_config = TraceProcessorConfig(bin_path=os.environ["SHELL_PATH"])
  return TraceProcessor(trace=trace, config=tp_config)
123
124
def example_android_trace_path():
  """Returns $ROOT_DIR/test/data/example_android_trace_30s.pb."""
  root = os.environ["ROOT_DIR"]
  relative_parts = ('test', 'data', 'example_android_trace_30s.pb')
  return os.path.join(root, *relative_parts)
128
129
class TestApi(unittest.TestCase):
  """Integration tests for the TraceProcessor and BatchTraceProcessor APIs.

  The helpers used here read SHELL_PATH (the trace processor binary) and
  ROOT_DIR (the directory containing test/data) from the environment.
  """

  # Durations of the first ten rows of `select * from slice` on the example
  # Android trace; shared by every test that loads that trace directly.
  _FIRST_SLICE_DURS = [
      178646, 119740, 58073, 155000, 173177, 20209377, 3589167, 90104, 275312,
      65313
  ]

  def _assert_first_slice_durs(self, tp):
    """Checks the first ten slices of `tp` and returns those rows."""
    rows = list(tp.query('select * from slice limit 10'))
    # Compare whole lists: checking rows one by one via enumerate() would
    # pass vacuously if the query returned fewer rows (or none at all).
    self.assertEqual([row.dur for row in rows], self._FIRST_SLICE_DURS)
    return rows

  def test_invalid_trace(self):
    f = io.BytesIO(b'<foo></foo>')
    with self.assertRaises(TraceProcessorException):
      _ = create_tp(trace=f)

  def test_runtime_error(self):
    # We emulate a situation when TP returns an error by passing the --version
    # flag. This makes TP output version information and exit, instead of
    # starting an http server.
    config = TraceProcessorConfig(
        bin_path=os.environ["SHELL_PATH"], extra_flags=["--version"])
    with self.assertRaisesRegex(
        TraceProcessorException,
        expected_regex='.*Trace Processor RPC API version:.*'):
      TraceProcessor(trace=io.BytesIO(b''), config=config)

  def test_trace_path(self):
    # The context manager closes the processor even if an assertion fails;
    # a trailing tp.close() would be skipped in that case.
    with create_tp(trace=example_android_trace_path()) as tp:
      rows = self._assert_first_slice_durs(tp)
      for row in rows:
        self.assertEqual(row.type, '__intrinsic_slice')

      # Test the batching logic by issuing a large query and ensuring we
      # receive all rows, not just a truncated subset.
      expected_count = next(tp.query('select count(*) as cnt from slice')).cnt
      self.assertGreater(expected_count, 0)

      count = sum(1 for _ in tp.query('select * from slice'))
      self.assertEqual(count, expected_count)

  def test_trace_byteio(self):
    f = io.BytesIO(
        b'\n(\n&\x08\x00\x12\x12\x08\x01\x10\xc8\x01\x1a\x0b\x12\t'
        b'B|200|foo\x12\x0e\x08\x02\x10\xc8\x01\x1a\x07\x12\x05E|200')
    with create_tp(trace=f) as tp:
      qr_iterator = tp.query('select * from slice limit 10')
      res = list(qr_iterator)

      self.assertEqual(len(res), 1)

      row = res[0]
      self.assertEqual(row.ts, 1)
      self.assertEqual(row.dur, 1)
      self.assertEqual(row.name, 'foo')

  def test_trace_file(self):
    with open(example_android_trace_path(), 'rb') as file:
      with create_tp(trace=file) as tp:
        self._assert_first_slice_durs(tp)

  def test_trace_generator(self):

    def reader_generator():
      # Stream the entire trace in fixed-size chunks. Yielding a single
      # file.read(1024) would hand over only the first KiB of the trace.
      chunk_size = 64 * 1024
      with open(example_android_trace_path(), 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
          yield chunk

    with create_tp(trace=reader_generator()) as tp:
      self._assert_first_slice_durs(tp)

  def test_simple_resolver(self):
    dur = [178646] * 4
    source = ['generator', 'path', 'path_resolver', 'file']
    expected = pd.DataFrame(list(zip(dur, source)), columns=['dur', 'source'])

    with create_batch_tp(
        traces='simple:path={}'.format(example_android_trace_path())) as btp:
      df = btp.query_and_flatten('select dur from slice limit 1')
      pd.testing.assert_frame_equal(df, expected, check_dtype=False)

    with create_batch_tp(
        traces=SimpleResolver(path=example_android_trace_path())) as btp:
      df = btp.query_and_flatten('select dur from slice limit 1')
      pd.testing.assert_frame_equal(df, expected, check_dtype=False)

  def test_query_timing(self):
    observer = SimpleObserver()
    with create_batch_tp(
        traces='simple:path={}'.format(example_android_trace_path()),
        observer=observer) as btp:
      btp.query_and_flatten('select dur from slice limit 1')
      # all() over an empty list is True, so also require that the observer
      # was actually notified.
      self.assertTrue(observer.execution_times,
                      'Observer should record at least one execution time')
      self.assertTrue(
          all(x > 0 for x in observer.execution_times),
          'Running time should be positive')

  def test_recursive_resolver(self):
    # Five leaf traces: the recursive generator plus the generator and path
    # results of each nested SimpleResolver (files are skipped).
    dur = [178646] * 5
    source = ['recursive_gen', 'generator', 'path', 'generator', 'path']
    root_source = [
        None, 'recursive_path', 'recursive_path', 'recursive_obj',
        'recursive_obj'
    ]
    expected = pd.DataFrame(
        list(zip(dur, source, root_source)),
        columns=['dur', 'source', 'root_source'])

    uri = 'recursive:path={};skip_resolve_file=true'.format(
        example_android_trace_path())
    with create_batch_tp(traces=uri) as btp:
      df = btp.query_and_flatten('select dur from slice limit 1')
      pd.testing.assert_frame_equal(df, expected, check_dtype=False)

    with create_batch_tp(
        traces=RecursiveResolver(
            path=example_android_trace_path(), skip_resolve_file=True)) as btp:
      df = btp.query_and_flatten('select dur from slice limit 1')
      pd.testing.assert_frame_equal(df, expected, check_dtype=False)

  def test_btp_load_failure(self):
    f = io.BytesIO(b'<foo></foo>')
    with self.assertRaises(TraceProcessorException):
      _ = create_batch_tp(traces=f)

  def test_btp_load_failure_increment_stat(self):
    f = io.BytesIO(b'<foo></foo>')
    with create_batch_tp(
        traces=f, load_failure_handling=FailureHandling.INCREMENT_STAT) as btp:
      self.assertEqual(btp.stats().load_failures, 1)

  def test_btp_query_failure(self):
    with create_batch_tp(traces=example_android_trace_path()) as btp:
      with self.assertRaises(TraceProcessorException):
        _ = btp.query('select * from sl')

  def test_btp_query_failure_increment_stat(self):
    with create_batch_tp(
        traces=example_android_trace_path(),
        execute_failure_handling=FailureHandling.INCREMENT_STAT) as btp:
      _ = btp.query('select * from sl')
      self.assertEqual(btp.stats().execute_failures, 1)

  def test_btp_query_failure_message(self):
    with create_batch_tp(
        traces='simple:path={}'.format(example_android_trace_path())) as btp:
      with self.assertRaisesRegex(
          TraceProcessorException, expected_regex='.*source.*generator.*'):
        _ = btp.query('select * from sl')

  def test_extra_flags(self):
    with tempfile.TemporaryDirectory() as temp_dir:
      test_module_dir = os.path.join(temp_dir, 'ext')
      os.makedirs(test_module_dir)
      test_module = os.path.join(test_module_dir, 'module.sql')
      with open(test_module, 'w') as f:
        f.write('CREATE TABLE test_table AS SELECT 123 AS test_value\n')
      config = TraceProcessorConfig(
          bin_path=os.environ["SHELL_PATH"],
          extra_flags=['--add-sql-module', test_module_dir])
      with TraceProcessor(trace=io.BytesIO(b''), config=config) as tp:
        qr_iterator = tp.query(
            'SELECT IMPORT("ext.module"); SELECT test_value FROM test_table')
        self.assertEqual(next(qr_iterator).test_value, 123)
309        self.assertEqual(next(qr_iterator).test_value, 123)
310