xref: /aosp_15_r20/external/pigweed/pw_rpc/ts/call.ts (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1// Copyright 2021 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15import { Status } from 'pigweedjs/pw_status';
16import { Message } from 'google-protobuf';
17
18import WaitQueue from './queue';
19
20import { PendingCalls, Rpc } from './rpc_classes';
21
/**
 * Signature of the user-supplied handlers stored on a `Call` (`onNext`,
 * `onCompleted`, `onError`). Each handler is invoked with a single argument:
 * the response message for `onNext`, or a `Status` for `onCompleted` and
 * `onError`. The handler's return value is not used by `Call`.
 */
export type Callback = (a: any) => any;
23
/**
 * Error describing an RPC that terminated with a non-OK `Status`.
 *
 * The message always names the RPC method and the status; NOT_FOUND and
 * DATA_LOSS additionally get a short human-readable explanation.
 */
class RpcError extends Error {
  /** Status code the RPC failed with. */
  status: Status;

  constructor(rpc: Rpc, status: Status) {
    // Optional explanatory suffix; empty for every other status code.
    let detail: string;
    switch (status) {
      case Status.NOT_FOUND:
        detail = ': the RPC server does not support this RPC';
        break;
      case Status.DATA_LOSS:
        detail = ': an error occurred while decoding the RPC payload';
        break;
      default:
        detail = '';
        break;
    }

    const statusName = Status[status];
    super(`${rpc.method.name} failed with error ${statusName}${detail}`);
    this.status = status;
  }
}
39
/**
 * Error raised when waiting for an RPC response exceeds the allowed time.
 *
 * Keeps the RPC that was being waited on (`rpc`) and the timeout that elapsed
 * in milliseconds (`timeoutMs`) so handlers can inspect them.
 */
class RpcTimeout extends Error {
  constructor(
    readonly rpc: Rpc,
    readonly timeoutMs: number,
  ) {
    super(`${rpc.method.name} timed out after ${timeoutMs} ms`);
  }
}
50
/**
 * Fixed-capacity history of received responses.
 *
 * Behaves as a ring buffer: once `maxResponses` entries have been stored, each
 * new response overwrites the oldest one.
 */
class Responses {
  // Backing storage, indexed modulo `maxResponses`.
  private responses: Message[] = [];
  // Number of responses ever pushed, including overwritten ones.
  private totalResponses = 0;

  /** @param maxResponses Maximum number of responses retained at once. */
  constructor(private readonly maxResponses: number) {}

  /** Number of responses currently retained (never above the capacity). */
  get length(): number {
    const { totalResponses, maxResponses } = this;
    return Math.min(totalResponses, maxResponses);
  }

  /** Stores `response`, replacing the oldest entry when the buffer is full. */
  push(response: Message): void {
    const slot = this.totalResponses % this.maxResponses;
    this.responses[slot] = response;
    this.totalResponses++;
  }

  /** Returns the most recently pushed response, or undefined if none. */
  last(): Message | undefined {
    const { totalResponses, maxResponses } = this;
    return totalResponses === 0
      ? undefined
      : this.responses[(totalResponses - 1) % maxResponses];
  }

  /** Returns a new array of the retained responses, oldest first. */
  getAll(): Message[] {
    const { responses, totalResponses, maxResponses } = this;

    // Buffer has not wrapped yet: storage order is already arrival order.
    if (totalResponses < maxResponses) {
      return responses.slice(0, totalResponses);
    }

    // After wrapping, the oldest entry sits at the next slot to be written.
    const oldest = totalResponses % maxResponses;
    const olderPart = responses.slice(oldest);
    const newerPart = responses.slice(0, oldest);
    return olderPart.concat(newerPart);
  }
}
89
/**
 * Represents an in-progress or completed RPC call.
 *
 * Every response is recorded twice: in a bounded `Responses` history (read by
 * the `*Wait` helpers) and in a wait queue that `getResponses()` drains. A call
 * counts as completed once either `status` or `error` has been set.
 */
export class Call {
  // Responses ordered by arrival time. Undefined signifies stream completion.
  private responseQueue = new WaitQueue<Message | undefined>();
  protected responses: Responses;

  private rpcs: PendingCalls;
  rpc: Rpc;
  readonly callId: number;

  private onNext: Callback;
  private onCompleted: Callback;
  private onError: Callback;

  // Set by handleCompletion().
  status?: Status;
  // Set by handleError() and cancel().
  error?: Status;
  // Most recent Error thrown by a user callback; rethrown by checkErrors().
  callbackException?: Error;

  /**
   * @param rpcs Object used to send packets for this call and to allocate its
   *    call ID.
   * @param rpc The RPC being called.
   * @param onNext Invoked with each response passed to handleResponse().
   * @param onCompleted Invoked with the status passed to handleCompletion().
   * @param onError Invoked with the status passed to handleError().
   * @param maxResponses Capacity of the retained response history.
   */
  constructor(
    rpcs: PendingCalls,
    rpc: Rpc,
    onNext: Callback,
    onCompleted: Callback,
    onError: Callback,
    maxResponses: number,
  ) {
    this.rpcs = rpcs;
    this.rpc = rpc;
    this.responses = new Responses(maxResponses);

    this.onNext = onNext;
    this.onCompleted = onCompleted;
    this.onError = onError;
    this.callId = rpcs.allocateCallId();
  }

  /**
   * Calls the RPC. This must be called immediately after construction.
   *
   * If sending the request returns a previous call that has not completed,
   * that call is terminated with CANCELLED.
   */
  invoke(request?: Message, ignoreErrors = false): void {
    const previous = this.rpcs.sendRequest(
      this.rpc,
      this,
      ignoreErrors,
      request,
    );

    if (previous !== undefined && !previous.completed) {
      previous.handleError(Status.CANCELLED);
    }
  }

  /** True once the call has finished, with either a status or an error. */
  get completed(): boolean {
    return this.status !== undefined || this.error !== undefined;
  }

  /**
   * Runs a user callback, ensuring nothing it throws escapes to the caller.
   *
   * A thrown Error is logged and stored in `callbackException` so that
   * checkErrors() can rethrow it later. Any other thrown value is only logged.
   */
  private invokeCallback(func: () => unknown): void {
    try {
      func();
    } catch (err: unknown) {
      if (err instanceof Error) {
        console.error(
          `An exception was raised while invoking a callback: ${err}`,
        );
        this.callbackException = err;
      } else {
        // Only non-Error values are "unexpected"; Errors were reported above.
        console.error(`Unexpected item thrown while invoking callback: ${err}`);
      }
    }
  }

  /** Records a response and forwards it to the onNext callback. */
  handleResponse(response: Message): void {
    this.responses.push(response);
    this.responseQueue.push(response);
    this.invokeCallback(() => this.onNext(response));
  }

  /** Marks the call as completed with `status` and notifies onCompleted. */
  handleCompletion(status: Status): void {
    this.status = status;
    // Undefined wakes any reader and signals the end of the stream.
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onCompleted(status));
  }

  /** Marks the call as failed with `error` and notifies onError. */
  handleError(error: Status): void {
    this.error = error;
    // Undefined wakes any reader so it can observe the error.
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onError(error));
  }

  /**
   * Takes the next queue entry, rejecting with RpcTimeout if none arrives
   * within `timeoutMs` milliseconds.
   *
   * If the timeout fires first, the queue read is still outstanding; when it
   * eventually yields an entry, that entry is put back at the front of the
   * queue so it is not lost. A failure of the queue read itself is forwarded
   * to the caller.
   */
  private queuePopWithTimeout(timeoutMs: number): Promise<Message | undefined> {
    return new Promise<Message | undefined>((resolve, reject) => {
      let timeoutExpired = false;
      const timeoutWatcher = setTimeout(() => {
        timeoutExpired = true;
        reject(new RpcTimeout(this.rpc, timeoutMs));
      }, timeoutMs);

      Promise.resolve(this.responseQueue.shift()).then(
        (response) => {
          if (timeoutExpired) {
            // The caller already saw a timeout; return the entry to the queue.
            this.responseQueue.unshift(response);
            return;
          }
          clearTimeout(timeoutWatcher);
          resolve(response);
        },
        (err: unknown) => {
          clearTimeout(timeoutWatcher);
          // No effect if the timeout already rejected the promise.
          reject(err);
        },
      );
    });
  }

  /**
   * Yields responses, up to the specified count, as they are added.
   *
   * Throws an error as soon as it is received even if there are still
   * responses in the queue.
   *
   * Usage
   * ```
   * for await (const response of call.getResponses(5)) {
   *  console.log(response);
   * }
   * ```
   *
   * @param {number} count The number of responses to read before returning.
   *    If no value is specified, getResponses will block until the stream
   *    either ends or hits an error.
   * @param {number} timeoutMs The number of milliseconds to wait for each
   *    response before throwing an RpcTimeout. If no value is specified,
   *    waits indefinitely.
   */
  async *getResponses(
    count?: number,
    timeoutMs?: number,
  ): AsyncGenerator<Message> {
    this.checkErrors();

    if (this.completed && this.responseQueue.length === 0) {
      return;
    }

    let remaining = count ?? Number.POSITIVE_INFINITY;
    while (remaining > 0) {
      const response =
        timeoutMs === undefined
          ? await this.responseQueue.shift()
          : await this.queuePopWithTimeout(timeoutMs);
      this.checkErrors();
      // Undefined is the end-of-stream marker.
      if (response === undefined) {
        return;
      }
      yield response;
      remaining -= 1;
    }
  }

  /**
   * Cancels the call.
   *
   * @returns false if the call had already completed; otherwise the result of
   *    sending the cancellation.
   */
  cancel(): boolean {
    if (this.completed) {
      return false;
    }

    this.error = Status.CANCELLED;
    return this.rpcs.sendCancel(this.rpc, this.callId);
  }

  /**
   * Throws the stored callback exception if there is one, otherwise an
   * RpcError if the call has an error status.
   */
  private checkErrors(): void {
    if (this.callbackException !== undefined) {
      throw this.callbackException;
    }
    if (this.error !== undefined) {
      throw new RpcError(this.rpc, this.error);
    }
  }

  /**
   * Waits for a call that is expected to produce exactly one response.
   *
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns The final status and the single response.
   * @throws If the call failed, timed out, has no status afterwards, or did
   *    not produce exactly one response.
   */
  protected async unaryWait(timeoutMs?: number): Promise<[Status, Message]> {
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    for await (const _response of this.getResponses(1, timeoutMs)) {
      // Only draining the queue; responses are read from `this.responses`.
    }
    if (this.status === undefined) {
      throw new Error('Unexpected undefined status at end of stream');
    }
    if (this.responses.length !== 1) {
      throw new Error(
        `Unexpected number of responses: ${this.responses.length}`,
      );
    }
    return [this.status, this.responses.last()!];
  }

  /**
   * Waits for the response stream to end.
   *
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns The final status and the retained responses, oldest first.
   * @throws If the call failed, timed out, or has no status afterwards.
   */
  protected async streamWait(timeoutMs?: number): Promise<[Status, Message[]]> {
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    for await (const _response of this.getResponses(undefined, timeoutMs)) {
      // Only draining the queue; responses are read from `this.responses`.
    }
    if (this.status === undefined) {
      throw new Error('Unexpected undefined status at end of stream');
    }
    return [this.status, this.responses.getAll()];
  }

  /**
   * Sends one client-stream message for this call.
   *
   * @throws The pending callback exception or RpcError if the call has
   *    failed, or an RpcError with FAILED_PRECONDITION if the call has
   *    already completed with a status.
   */
  protected sendClientStream(request: Message): void {
    this.checkErrors();
    if (this.status !== undefined) {
      throw new RpcError(this.rpc, Status.FAILED_PRECONDITION);
    }
    this.rpcs.sendClientStream(this.rpc, request, this.callId);
  }

  /**
   * Sends any remaining `requests`, then ends the client stream unless the
   * call has already completed.
   */
  protected finishClientStream(requests: Message[]): void {
    for (const request of requests) {
      this.sendClientStream(request);
    }

    if (!this.completed) {
      this.rpcs.sendClientStreamEnd(this.rpc, this.callId);
    }
  }
}
300
/** Call handle for a unary RPC: one request, one response. */
export class UnaryCall extends Call {
  /**
   * Waits for the server's response.
   *
   * @param timeoutMs Optional timeout in milliseconds for the response.
   * @returns The final status together with the single response message.
   */
  async complete(timeoutMs?: number): Promise<[Status, Message]> {
    const outcome = await this.unaryWait(timeoutMs);
    return outcome;
  }
}
308
/** Call handle for a client streaming RPC: many requests, one response. */
export class ClientStreamingCall extends Call {
  /** The most recent server message, or undefined if none has arrived. */
  get response(): Message | undefined {
    const latest = this.responses.last();
    return latest;
  }

  /** Sends one message on the client stream. */
  send(request: Message): void {
    this.sendClientStream(request);
  }

  /**
   * Sends any final `requests`, ends the client stream, and waits for the RPC
   * to complete.
   *
   * @param requests Messages to send before ending the stream.
   * @param timeoutMs Optional timeout in milliseconds for the response.
   * @returns The final status together with the single response message.
   */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number,
  ): Promise<[Status, Message]> {
    this.finishClientStream(requests);
    const outcome = await this.unaryWait(timeoutMs);
    return outcome;
  }
}
330
/** Call handle for a server streaming RPC: one request, many responses. */
export class ServerStreamingCall extends Call {
  /**
   * Waits for the response stream to end.
   *
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns A promise for the final status and the retained responses.
   */
  complete(timeoutMs?: number): Promise<[Status, Message[]]> {
    const pending = this.streamWait(timeoutMs);
    return pending;
  }
}
337
/** Call handle for a bidirectional streaming RPC. */
export class BidirectionalStreamingCall extends Call {
  /** Sends one message on the client stream. */
  send(request: Message): void {
    this.sendClientStream(request);
  }

  /**
   * Sends any final `requests`, ends the client stream, and waits for the
   * server's response stream to end.
   *
   * @param requests Messages to send before ending the stream.
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns The final status and the retained responses.
   */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number,
  ): Promise<[Status, Message[]]> {
    this.finishClientStream(requests);
    const outcome = await this.streamWait(timeoutMs);
    return outcome;
  }
}
354