1// Copyright (C) 2022 The Android Open Source Project
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15import protobuf from 'protobufjs/minimal';
16import {defer, Deferred} from '../../../base/deferred';
17import {assertExists, assertFalse, assertTrue} from '../../../base/logging';
18import {
19  DisableTracingRequest,
20  DisableTracingResponse,
21  EnableTracingRequest,
22  EnableTracingResponse,
23  FreeBuffersRequest,
24  FreeBuffersResponse,
25  GetTraceStatsRequest,
26  GetTraceStatsResponse,
27  IBufferStats,
28  IMethodInfo,
29  IPCFrame,
30  ISlice,
31  QueryServiceStateRequest,
32  QueryServiceStateResponse,
33  ReadBuffersRequest,
34  ReadBuffersResponse,
35  TraceConfig,
36} from '../protos';
37import {RecordingError} from './recording_error_handling';
38import {
39  ByteStream,
40  DataSource,
41  TracingSession,
42  TracingSessionListener,
43} from './recording_interfaces_v2';
44import {
45  BUFFER_USAGE_INCORRECT_FORMAT,
46  BUFFER_USAGE_NOT_ACCESSIBLE,
47  PARSING_UNABLE_TO_DECODE_METHOD,
48  PARSING_UNKNWON_REQUEST_ID,
49  PARSING_UNRECOGNIZED_MESSAGE,
50  PARSING_UNRECOGNIZED_PORT,
51  RECORDING_IN_PROGRESS,
52} from './recording_utils';
53import {exists} from '../../../base/utils';
54
// Each IPC frame on the wire is preceded by its length, encoded as a 32-bit
// little-endian integer. See wire_protocol.proto for more details.
const WIRE_PROTOCOL_HEADER_SIZE = 4;

// Initial capacity of the receive buffer (128 KiB). Mirrors kIPCBufferSize,
// see basic_types.h for more details.
const MAX_IPC_BUFFER_SIZE = 128 * 1024;

// Protobuf wire type used by length-delimited fields (bytes/messages).
const PROTO_LEN_DELIMITED_WIRE_TYPE = 2;
// Field number under which each reassembled TracePacket is written into the
// output trace.
const TRACE_PACKET_PROTO_ID = 1;
// Field key emitted before every TracePacket: (field number << 3) | wire type.
const TRACE_PACKET_PROTO_TAG =
  (TRACE_PACKET_PROTO_ID << 3) | PROTO_LEN_DELIMITED_WIRE_TYPE;
64
/**
 * Decodes the frame-length header stored in the first four bytes of |buffer|
 * as an unsigned little-endian 32-bit integer. Bytes after the header are
 * ignored; fewer than four bytes make DataView throw a RangeError.
 */
function parseMessageSize(buffer: Uint8Array) {
  const littleEndian = true;
  const view = new DataView(
    buffer.buffer,
    buffer.byteOffset,
    buffer.byteLength,
  );
  return view.getUint32(0, littleEndian);
}
69
/**
 * Consumer side of the traced IPC protocol described in
 * https://perfetto.dev/docs/design-docs/api-and-abi#tracing-protocol-abi
 *
 * Every frame exchanged over |byteStream| is a 4-byte little-endian length
 * header followed by an encoded IPCFrame. Replies to invoked methods are
 * matched to the method that was called through the frame's request ID.
 */
export class TracedTracingSession implements TracingSession {
  // Buffers received wire protocol data. It starts at kIPCBufferSize and is
  // grown on demand should a larger frame ever arrive.
  private incomingBuffer = new Uint8Array(MAX_IPC_BUFFER_SIZE);
  // Number of valid bytes at the start of |incomingBuffer|.
  private bufferedPartLength = 0;
  // Body length of the frame currently being assembled, once its header has
  // been read. Undefined while (part of) a header is still pending.
  private currentFrameLength?: number;

  // Methods and service ID returned when binding to 'ConsumerPort'.
  private availableMethods: IMethodInfo[] = [];
  private serviceId = -1;

  // Created by initConnection(); resolved by a BindService reply that carries
  // the method list and service ID.
  private resolveBindingPromise?: Deferred<void>;
  // Request ID -> name of the invoked method, used to pick the reply decoder.
  // An entry is removed once the final reply (hasMore != true) is received.
  private requestMethods = new Map<number, string>();

  // Needed for ReadBufferResponse: all the trace packets are split into
  // several slices. |partialPacket| is the buffer for them. Once we receive a
  // slice with the flag |lastSliceForPacket|, a new packet is created.
  private partialPacket: ISlice[] = [];
  // Accumulates trace packets into a proto trace file.
  private traceProtoWriter = protobuf.Writer.create();

  // Accumulates DataSource objects from QueryServiceStateResponse,
  // which can have >1 replies for each query
  // go/codesearch/android/external/perfetto/protos/
  // perfetto/ipc/consumer_port.proto;l=243-246
  private pendingDataSources: DataSource[] = [];

  // For concurrent calls to 'QueryServiceState', we return the same value.
  private pendingQssMessage?: Deferred<DataSource[]>;

  // Wire protocol request ID. After each request it is increased. It is needed
  // to keep track of the type of request, and parse the response correctly.
  private requestId = 1;

  // Callers of getBufferStats() waiting for a GetTraceStats reply, oldest
  // first; replies are matched to them in FIFO order.
  private pendingStatsMessages = new Array<Deferred<IBufferStats[]>>();

  // The bytestream is obtained when creating a connection with a target.
  // For instance, the AdbStream is obtained from a connection with an Adb
  // device.
  constructor(
    private byteStream: ByteStream,
    private tracingSessionListener: TracingSessionListener,
  ) {
    this.byteStream.addOnStreamDataCallback((data) =>
      this.handleReceivedData(data),
    );
    this.byteStream.addOnStreamCloseCallback(() => this.clearState());
  }

  /**
   * Returns the data sources reported by the service. Concurrent calls share
   * the same in-flight request and promise. The promise is rejected if the
   * connection is torn down or the service reports a failure before the last
   * reply arrives.
   */
  queryServiceState(): Promise<DataSource[]> {
    if (this.pendingQssMessage) {
      return this.pendingQssMessage;
    }

    const requestProto = QueryServiceStateRequest.encode(
      new QueryServiceStateRequest(),
    ).finish();
    this.rpcInvoke('QueryServiceState', requestProto);

    return (this.pendingQssMessage = defer<DataSource[]>());
  }

  /**
   * Reports the recording status to the listener and sends EnableTracing with
   * |config|. ReadBuffers is issued once the EnableTracing reply arrives.
   */
  start(config: TraceConfig): void {
    const duration = config.durationMs;
    this.tracingSessionListener.onStatus(
      `${RECORDING_IN_PROGRESS}${
        duration ? ' for ' + duration.toString() + ' ms' : ''
      }...`,
    );

    const enableTracingRequest = new EnableTracingRequest();
    enableTracingRequest.traceConfig = config;
    const enableTracingRequestProto =
      EnableTracingRequest.encode(enableTracingRequest).finish();
    this.rpcInvoke('EnableTracing', enableTracingRequestProto);
  }

  /** Drops pending state, frees the service-side buffers and closes. */
  cancel(): void {
    this.terminateConnection();
  }

  /** Sends DisableTracing to the service. */
  stop(): void {
    const requestProto = DisableTracingRequest.encode(
      new DisableTracingRequest(),
    ).finish();
    this.rpcInvoke('DisableTracing', requestProto);
  }

  /**
   * Returns the highest bytesWritten / bufferSize ratio over all trace
   * buffers, or 0 if the stream is not connected.
   *
   * Rejects with BUFFER_USAGE_INCORRECT_FORMAT when no buffer reports usable
   * stats, and with BUFFER_USAGE_NOT_ACCESSIBLE when the stats request fails
   * or the connection closes before the reply arrives.
   */
  async getTraceBufferUsage(): Promise<number> {
    if (!this.byteStream.isConnected()) {
      // TODO(octaviant): make this more in line with the other trace buffer
      //  error cases.
      return 0;
    }
    const bufferStats = await this.getBufferStats();
    let percentageUsed = -1;
    for (const buffer of bufferStats) {
      if (
        !Number.isFinite(buffer.bytesWritten) ||
        !Number.isFinite(buffer.bufferSize)
      ) {
        continue;
      }
      const used = assertExists(buffer.bytesWritten);
      const total = assertExists(buffer.bufferSize);
      // Zero-sized buffers are skipped: 0 / 0 is NaN, and Math.max would then
      // turn the whole result into NaN.
      if (total > 0) {
        percentageUsed = Math.max(percentageUsed, used / total);
      }
    }

    if (percentageUsed === -1) {
      throw new RecordingError(BUFFER_USAGE_INCORRECT_FORMAT);
    }
    return percentageUsed;
  }

  /**
   * Binds to the 'ConsumerPort' service. The returned promise resolves when
   * the BindService reply with the available methods is received.
   */
  initConnection(): Promise<void> {
    // We shouldn't bind multiple times to the service in the same tracing
    // session. Check before anything is written to the stream.
    assertFalse(this.resolveBindingPromise !== undefined);
    this.resolveBindingPromise = defer<void>();

    const requestId = this.requestId++;
    const frame = new IPCFrame({
      requestId,
      msgBindService: new IPCFrame.BindService({serviceName: 'ConsumerPort'}),
    });
    this.writeFrame(frame);
    return this.resolveBindingPromise;
  }

  // Sends GetTraceStats and returns a promise for the buffer stats in its
  // reply.
  private getBufferStats(): Promise<IBufferStats[]> {
    const getTraceStatsRequestProto = GetTraceStatsRequest.encode(
      new GetTraceStatsRequest(),
    ).finish();
    try {
      this.rpcInvoke('GetTraceStats', getTraceStatsRequestProto);
    } catch (e) {
      // GetTraceStats was introduced only on Android 10.
      this.raiseError(e instanceof Error ? e.message : String(e));
      // raiseError() tore the connection down, so no reply can arrive: fail
      // now rather than return a promise that never settles.
      return Promise.reject(new RecordingError(BUFFER_USAGE_NOT_ACCESSIBLE));
    }

    const statsMessage = defer<IBufferStats[]>();
    this.pendingStatsMessages.push(statsMessage);
    return statsMessage;
  }

  // Fails pending requests, asks the service to free its buffers and closes
  // the stream.
  private terminateConnection(): void {
    this.clearState();
    const requestProto = FreeBuffersRequest.encode(
      new FreeBuffersRequest(),
    ).finish();
    try {
      this.rpcInvoke('FreeBuffers', requestProto);
    } catch {
      // Freeing buffers is best-effort during teardown (e.g. the method is
      // unknown because the service was never bound); the stream must still
      // be closed below.
    }
    this.byteStream.close();
  }

  // Rejects every promise still waiting for a reply and drops partial
  // results. Safe to call more than once.
  private clearState() {
    for (const statsMessage of this.pendingStatsMessages) {
      statsMessage.reject(new RecordingError(BUFFER_USAGE_NOT_ACCESSIBLE));
    }
    this.pendingStatsMessages = [];
    this.rejectPendingQss();
  }

  // Fails the in-flight QueryServiceState call, if any, and discards the data
  // sources accumulated for it.
  private rejectPendingQss(): void {
    this.pendingQssMessage?.reject(
      new RecordingError('Unable to query the service state of the target'),
    );
    this.pendingQssMessage = undefined;
    this.pendingDataSources = [];
  }

  // Sends an InvokeMethod frame for |methodName|. Silently does nothing when
  // the stream is disconnected; throws a RecordingError when the bound service
  // did not advertise the method.
  private rpcInvoke(methodName: string, argsProto: Uint8Array): void {
    if (!this.byteStream.isConnected()) {
      return;
    }
    const method = this.availableMethods.find((m) => m.name === methodName);
    if (!exists(method) || !exists(method.id)) {
      throw new RecordingError(
        `Method ${methodName} not supported by the target`,
      );
    }
    const requestId = this.requestId++;
    const frame = new IPCFrame({
      requestId,
      msgInvokeMethod: new IPCFrame.InvokeMethod({
        serviceId: this.serviceId,
        methodId: method.id,
        argsProto,
      }),
    });
    this.requestMethods.set(requestId, methodName);
    this.writeFrame(frame);
  }

  // Writes |frame| prefixed with its little-endian 32-bit length.
  private writeFrame(frame: IPCFrame): void {
    const frameProto: Uint8Array = IPCFrame.encode(frame).finish();
    const buf = new Uint8Array(WIRE_PROTOCOL_HEADER_SIZE + frameProto.length);
    new DataView(buf.buffer).setUint32(
      0,
      frameProto.length,
      /* littleEndian */ true,
    );
    buf.set(frameProto, WIRE_PROTOCOL_HEADER_SIZE);
    this.byteStream.write(buf);
  }

  // Reassembles frames from arbitrarily split stream chunks: parses every
  // frame completed by |rawData| and buffers the leftover bytes.
  private handleReceivedData(rawData: Uint8Array): void {
    // We parse the length of the next frame if it's available.
    if (
      this.currentFrameLength === undefined &&
      this.canCompleteLengthHeader(rawData)
    ) {
      const missingHeaderBytes =
        WIRE_PROTOCOL_HEADER_SIZE - this.bufferedPartLength;
      this.appendToIncomingBuffer(rawData.subarray(0, missingHeaderBytes));
      rawData = rawData.subarray(missingHeaderBytes);

      this.currentFrameLength = parseMessageSize(this.incomingBuffer);
      this.bufferedPartLength = 0;
    }

    // Parse all complete frames.
    while (
      this.currentFrameLength !== undefined &&
      this.bufferedPartLength + rawData.length >= this.currentFrameLength
    ) {
      // Read the remaining part of this message.
      const bytesToCompleteMessage =
        this.currentFrameLength - this.bufferedPartLength;
      this.appendToIncomingBuffer(rawData.subarray(0, bytesToCompleteMessage));
      this.parseFrame(this.incomingBuffer.subarray(0, this.currentFrameLength));
      this.bufferedPartLength = 0;
      // Remove the data just parsed.
      rawData = rawData.subarray(bytesToCompleteMessage);

      if (!this.canCompleteLengthHeader(rawData)) {
        this.currentFrameLength = undefined;
        break;
      }
      this.currentFrameLength = parseMessageSize(rawData);
      rawData = rawData.subarray(WIRE_PROTOCOL_HEADER_SIZE);
    }

    // Buffer the remaining data (part of the next message).
    this.appendToIncomingBuffer(rawData);
  }

  // True when the buffered bytes plus |newData| hold a whole length header.
  private canCompleteLengthHeader(newData: Uint8Array): boolean {
    return (
      newData.length + this.bufferedPartLength >= WIRE_PROTOCOL_HEADER_SIZE
    );
  }

  // Appends |array| after the buffered bytes, growing the buffer if needed.
  private appendToIncomingBuffer(array: Uint8Array): void {
    const required = this.bufferedPartLength + array.length;
    if (required > this.incomingBuffer.length) {
      // Grow instead of letting TypedArray.set throw a RangeError on a frame
      // larger than the initial capacity.
      const grown = new Uint8Array(
        Math.max(required, 2 * this.incomingBuffer.length),
      );
      grown.set(this.incomingBuffer.subarray(0, this.bufferedPartLength));
      this.incomingBuffer = grown;
    }
    this.incomingBuffer.set(array, this.bufferedPartLength);
    this.bufferedPartLength = required;
  }

  // Decodes one complete IPCFrame and dispatches it by message type.
  private parseFrame(frameBuffer: Uint8Array): void {
    // Get a copy of the ArrayBuffer to avoid the original being overriden.
    // See 170256902#comment21
    const frame = IPCFrame.decode(frameBuffer.slice());
    if (frame.msg === 'msgBindServiceReply') {
      this.handleBindServiceReply(frame);
    } else if (frame.msg === 'msgInvokeMethodReply') {
      this.handleInvokeMethodReply(frame);
    } else {
      this.raiseError(`${PARSING_UNRECOGNIZED_MESSAGE}: ${frame.msg}`);
    }
  }

  // Records the bound service's methods and ID, then resolves the promise
  // returned by initConnection().
  private handleBindServiceReply(frame: IPCFrame): void {
    const msgBindServiceReply = frame.msgBindServiceReply;
    if (
      exists(msgBindServiceReply) &&
      exists(msgBindServiceReply.methods) &&
      exists(msgBindServiceReply.serviceId)
    ) {
      assertTrue(msgBindServiceReply.success === true);
      this.availableMethods = msgBindServiceReply.methods;
      this.serviceId = msgBindServiceReply.serviceId;
      this.resolveBindingPromise?.resolve();
    }
  }

  // Decodes an InvokeMethod reply with the decoder of the method that was
  // invoked under the same request ID and acts on it.
  private handleInvokeMethodReply(frame: IPCFrame): void {
    const msgInvokeMethodReply = frame.msgInvokeMethodReply;
    if (!exists(msgInvokeMethodReply)) {
      return;
    }
    const method = this.requestMethods.get(frame.requestId);
    // Streamed replies reuse the request ID until hasMore is false.
    const isFinalReply = msgInvokeMethodReply.hasMore !== true;
    if (isFinalReply) {
      this.requestMethods.delete(frame.requestId);
    }

    // We process messages without a `replyProto` field (for instance
    // `FreeBuffers` does not have `replyProto`). However, we ignore messages
    // without a valid 'success' field, only settling the promises that
    // would otherwise wait for them forever.
    if (msgInvokeMethodReply.success !== true) {
      this.handleFailedReply(method, isFinalReply);
      return;
    }

    if (!method) {
      this.raiseError(`${PARSING_UNKNWON_REQUEST_ID}: ${frame.requestId}`);
      return;
    }
    const decoder = decoders.get(method);
    if (decoder === undefined) {
      this.raiseError(`${PARSING_UNABLE_TO_DECODE_METHOD}: ${method}`);
      return;
    }
    const data = {...decoder(msgInvokeMethodReply.replyProto)};

    switch (method) {
      case 'ReadBuffers':
        for (const slice of data.slices ?? []) {
          this.partialPacket.push(slice);
          if (slice.lastSliceForPacket === true) {
            this.flushPartialPacket();
          }
        }
        if (msgInvokeMethodReply.hasMore === false) {
          this.tracingSessionListener.onTraceData(
            this.traceProtoWriter.finish(),
          );
          this.terminateConnection();
        }
        break;
      case 'EnableTracing': {
        const readBuffersRequestProto = ReadBuffersRequest.encode(
          new ReadBuffersRequest(),
        ).finish();
        this.rpcInvoke('ReadBuffers', readBuffersRequestProto);
        break;
      }
      case 'GetTraceStats':
        this.pendingStatsMessages
          .shift()
          ?.resolve(data?.traceStats?.bufferStats ?? []);
        break;
      case 'FreeBuffers':
      case 'DisableTracing':
        // No action required. If we successfully read a whole trace,
        // we close the connection. Alternatively, if the tracing finishes
        // with an exception or if the user cancels it, we also close the
        // connection.
        break;
      case 'QueryServiceState':
        this.handleQueryServiceStateReply(
          data as QueryServiceStateResponse,
          msgInvokeMethodReply.hasMore === false,
        );
        break;
      default:
        this.raiseError(`${PARSING_UNRECOGNIZED_PORT}: ${method}`);
    }
  }

  // Settles promises tied to a reply the service reported as unsuccessful.
  private handleFailedReply(
    method: string | undefined,
    isFinalReply: boolean,
  ): void {
    if (method === 'GetTraceStats') {
      this.pendingStatsMessages
        .shift()
        ?.reject(new RecordingError(BUFFER_USAGE_NOT_ACCESSIBLE));
    } else if (method === 'QueryServiceState' && isFinalReply) {
      this.rejectPendingQss();
    }
  }

  // Concatenates the buffered slices into one TracePacket and appends it, as
  // a length-delimited field, to the trace being accumulated.
  private flushPartialPacket(): void {
    const chunks = this.partialPacket.map(
      (slice) => slice.data ?? new Uint8Array(0),
    );
    const packetSize = chunks.reduce((total, c) => total + c.length, 0);
    const tracePacket = new Uint8Array(packetSize);
    let written = 0;
    for (const chunk of chunks) {
      tracePacket.set(chunk, written);
      written += chunk.length;
    }
    this.traceProtoWriter.uint32(TRACE_PACKET_PROTO_TAG);
    this.traceProtoWriter.bytes(tracePacket);
    this.partialPacket = [];
  }

  // Collects named data sources from one QueryServiceState reply; the last
  // reply resolves the pending query (if it is still pending).
  private handleQueryServiceStateReply(
    response: QueryServiceStateResponse,
    isLastReply: boolean,
  ): void {
    const dataSources = response?.serviceState?.dataSources ?? [];
    for (const dataSource of dataSources) {
      const name = dataSource?.dsDescriptor?.name;
      if (name) {
        this.pendingDataSources.push({
          name,
          descriptor: dataSource.dsDescriptor,
        });
      }
    }
    if (isLastReply) {
      // The query may already have been rejected by clearState().
      this.pendingQssMessage?.resolve(this.pendingDataSources);
      this.pendingDataSources = [];
      this.pendingQssMessage = undefined;
    }
  }

  // Tears the connection down, then reports |message| to the listener.
  private raiseError(message: string): void {
    this.terminateConnection();
    this.tracingSessionListener.onError(message);
  }
}
432
// Reply decoder for each consumer-port method this session invokes, keyed by
// method name. A method missing here is reported as
// PARSING_UNABLE_TO_DECODE_METHOD when its reply arrives.
const decoders = new Map<string, Function>([
  ['EnableTracing', EnableTracingResponse.decode],
  ['FreeBuffers', FreeBuffersResponse.decode],
  ['ReadBuffers', ReadBuffersResponse.decode],
  ['DisableTracing', DisableTracingResponse.decode],
  ['GetTraceStats', GetTraceStatsResponse.decode],
  ['QueryServiceState', QueryServiceStateResponse.decode],
]);
440