// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

import { Status } from 'pigweedjs/pw_status';
import { Message } from 'google-protobuf';

import WaitQueue from './queue';

import { PendingCalls, Rpc } from './rpc_classes';

/** Signature of the onNext / onCompleted / onError callbacks given to a Call. */
export type Callback = (a: any) => any;

/**
 * Error thrown when an RPC has failed or is used in an invalid state.
 * Carries the `Status` that caused it.
 */
class RpcError extends Error {
  status: Status;

  constructor(rpc: Rpc, status: Status) {
    // Extra context for statuses that have a well-known client-side meaning.
    let message = '';
    if (status === Status.NOT_FOUND) {
      message = ': the RPC server does not support this RPC';
    } else if (status === Status.DATA_LOSS) {
      message = ': an error occurred while decoding the RPC payload';
    }

    super(`${rpc.method.name} failed with error ${Status[status]}${message}`);
    this.status = status;
  }
}

/** Error thrown when no response arrives within the requested timeout. */
class RpcTimeout extends Error {
  readonly rpc: Rpc;
  readonly timeoutMs: number;

  constructor(rpc: Rpc, timeoutMs: number) {
    super(`${rpc.method.name} timed out after ${timeoutMs} ms`);
    this.rpc = rpc;
    this.timeoutMs = timeoutMs;
  }
}

/**
 * Ring buffer that keeps only the most recent `maxResponses` responses while
 * counting every response ever pushed.
 */
class Responses {
  private responses: Message[] = [];
  private totalResponses = 0;
  private readonly maxResponses: number;

  /**
   * @param maxResponses Maximum number of responses retained. Values that are
   *     not positive (including NaN) mean "retain nothing"; fractional values
   *     are rounded down. `Infinity` retains everything.
   */
  constructor(maxResponses: number) {
    // A zero, negative, fractional or NaN modulus would produce NaN or
    // fractional array indices in push()/last()/getAll().
    this.maxResponses = maxResponses > 0 ? Math.floor(maxResponses) : 0;
  }

  /** Number of responses currently retained. */
  get length(): number {
    return Math.min(this.totalResponses, this.maxResponses);
  }

  /** Stores a response, overwriting the oldest one once the buffer is full. */
  push(response: Message): void {
    if (this.maxResponses > 0) {
      this.responses[this.totalResponses % this.maxResponses] = response;
    }
    this.totalResponses += 1;
  }

  /** Returns the most recently pushed response, if one is retained. */
  last(): Message | undefined {
    if (this.totalResponses === 0 || this.maxResponses === 0) {
      return undefined;
    }

    const lastIndex = (this.totalResponses - 1) % this.maxResponses;
    return this.responses[lastIndex];
  }

  /** Returns the retained responses ordered from oldest to newest. */
  getAll(): Message[] {
    if (this.maxResponses === 0) {
      return [];
    }
    if (this.totalResponses < this.maxResponses) {
      return this.responses.slice(0, this.totalResponses);
    }

    // The buffer has wrapped: the oldest entry sits at the next write index.
    const splitIndex = this.totalResponses % this.maxResponses;
    return this.responses
      .slice(splitIndex)
      .concat(this.responses.slice(0, splitIndex));
  }
}

/** Represent an in-progress or completed RPC call. */
export class Call {
  // Responses ordered by arrival time. Undefined signifies stream completion.
  private responseQueue = new WaitQueue<Message | undefined>();
  protected responses: Responses;

  private rpcs: PendingCalls;
  rpc: Rpc;
  readonly callId: number;

  private onNext: Callback;
  private onCompleted: Callback;
  private onError: Callback;

  status?: Status;
  error?: Status;
  callbackException?: Error;

  constructor(
    rpcs: PendingCalls,
    rpc: Rpc,
    onNext: Callback,
    onCompleted: Callback,
    onError: Callback,
    maxResponses: number,
  ) {
    this.rpcs = rpcs;
    this.rpc = rpc;
    this.responses = new Responses(maxResponses);

    this.onNext = onNext;
    this.onCompleted = onCompleted;
    this.onError = onError;
    this.callId = rpcs.allocateCallId();
  }

  /**
   * Calls the RPC. This must be called immediately after construction.
   *
   * If `sendRequest` hands back a previous call that has not completed, that
   * call is failed locally with `Status.CANCELLED`.
   */
  invoke(request?: Message, ignoreErrors = false): void {
    const previous = this.rpcs.sendRequest(
      this.rpc,
      this,
      ignoreErrors,
      request,
    );

    if (previous !== undefined && !previous.completed) {
      previous.handleError(Status.CANCELLED);
    }
  }

  /** True once either a final status or an error has been recorded. */
  get completed(): boolean {
    return this.status !== undefined || this.error !== undefined;
  }

  /**
   * Runs a user callback, logging anything it throws. Thrown `Error`s are kept
   * in `callbackException` so that `checkErrors()` can rethrow them to readers.
   */
  private invokeCallback(func: () => unknown): void {
    try {
      func();
    } catch (err: unknown) {
      if (err instanceof Error) {
        console.error(
          `An exception was raised while invoking a callback: ${err}`,
        );
        this.callbackException = err;
      } else {
        console.error(`Unexpected item thrown while invoking callback: ${err}`);
      }
    }
  }

  /** Records a server response, wakes readers, and invokes `onNext`. */
  handleResponse(response: Message): void {
    this.responses.push(response);
    this.responseQueue.push(response);
    this.invokeCallback(() => this.onNext(response));
  }

  /** Records the final status, ends the stream, and invokes `onCompleted`. */
  handleCompletion(status: Status): void {
    this.status = status;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onCompleted(status));
  }

  /** Records an error status, ends the stream, and invokes `onError`. */
  handleError(error: Status): void {
    this.error = error;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onError(error));
  }

  /**
   * Pops the next queued item, rejecting with `RpcTimeout` if nothing arrives
   * within `timeoutMs`. An item that arrives after the timeout fired is put
   * back at the front of the queue so a later reader still receives it.
   */
  private queuePopWithTimeout(
    timeoutMs: number,
  ): Promise<Message | undefined> {
    return new Promise((resolve, reject) => {
      let timeoutExpired = false;
      const timeoutWatcher = setTimeout(() => {
        timeoutExpired = true;
        reject(new RpcTimeout(this.rpc, timeoutMs));
      }, timeoutMs);

      // Promise.resolve() tolerates shift() returning either a value or a
      // promise; the rejection handler keeps failures from being swallowed
      // (which an async executor would do).
      Promise.resolve(this.responseQueue.shift()).then(
        (response) => {
          if (timeoutExpired) {
            this.responseQueue.unshift(response);
            return;
          }
          clearTimeout(timeoutWatcher);
          resolve(response);
        },
        (err: unknown) => {
          clearTimeout(timeoutWatcher);
          reject(err);
        },
      );
    });
  }

  /**
   * Yields responses up to the specified count as they are added.
   *
   * Throws an error as soon as it is observed, even if there are still
   * responses in the queue.
   *
   * Usage
   * ```
   * for await (const response of call.getResponses(5)) {
   *   console.log(response);
   * }
   * ```
   *
   * @param count The number of responses to read before returning. If no
   *     value is specified, getResponses will block until the stream either
   *     ends or hits an error.
   * @param timeoutMs The number of milliseconds to wait for each response
   *     before throwing an `RpcTimeout` error.
   */
  async *getResponses(
    count?: number,
    timeoutMs?: number,
  ): AsyncGenerator<Message> {
    this.checkErrors();

    // The completion sentinel was already consumed by an earlier reader.
    if (this.completed && this.responseQueue.length === 0) {
      return;
    }

    let remaining = count ?? Number.POSITIVE_INFINITY;
    while (remaining > 0) {
      const response =
        timeoutMs === undefined
          ? await this.responseQueue.shift()
          : await this.queuePopWithTimeout(timeoutMs);
      this.checkErrors();
      if (response === undefined) {
        return;
      }
      yield response;
      remaining -= 1;
    }
  }

  /**
   * Cancels the call if it is still active.
   *
   * @returns false if the call had already completed; otherwise the result of
   *     `sendCancel`.
   */
  cancel(): boolean {
    if (this.completed) {
      return false;
    }

    this.error = Status.CANCELLED;
    // Wake any reader blocked in getResponses(). It will see the cancellation
    // through checkErrors() and throw instead of waiting forever.
    this.responseQueue.push(undefined);
    return this.rpcs.sendCancel(this.rpc, this.callId);
  }

  /** Rethrows a captured callback exception, or throws for a recorded error. */
  private checkErrors(): void {
    if (this.callbackException !== undefined) {
      throw this.callbackException;
    }
    if (this.error !== undefined) {
      throw new RpcError(this.rpc, this.error);
    }
  }

  /** Waits for exactly one response plus the final status. */
  protected async unaryWait(timeoutMs?: number): Promise<[Status, Message]> {
    for await (const _response of this.getResponses(1, timeoutMs)) {
      // Do nothing; handleResponse/handleCompletion record the results.
    }
    const status = this.status;
    if (status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    if (this.responses.length !== 1) {
      throw Error(`Unexpected number of responses: ${this.responses.length}`);
    }
    // length === 1 guarantees last() returns a value.
    return [status, this.responses.last()!];
  }

  /** Waits for the stream to end and returns the status and kept responses. */
  protected async streamWait(timeoutMs?: number): Promise<[Status, Message[]]> {
    for await (const _response of this.getResponses(undefined, timeoutMs)) {
      // Do nothing; handleResponse/handleCompletion record the results.
    }
    const status = this.status;
    if (status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    return [status, this.responses.getAll()];
  }

  /** Sends one client-stream message; throws if the call already finished. */
  protected sendClientStream(request: Message) {
    this.checkErrors();
    if (this.status !== undefined) {
      throw new RpcError(this.rpc, Status.FAILED_PRECONDITION);
    }
    this.rpcs.sendClientStream(this.rpc, request, this.callId);
  }

  /** Sends any final messages, then signals the end of the client stream. */
  protected finishClientStream(requests: Message[]) {
    for (const request of requests) {
      this.sendClientStream(request);
    }

    if (!this.completed) {
      this.rpcs.sendClientStreamEnd(this.rpc, this.callId);
    }
  }
}

/** Tracks the state of a unary RPC call. */
export class UnaryCall extends Call {
  /** Awaits the server response */
  async complete(timeoutMs?: number): Promise<[Status, Message]> {
    return await this.unaryWait(timeoutMs);
  }
}

/** Tracks the state of a client streaming RPC call. */
export class ClientStreamingCall extends Call {
  /** Gets the last server message, if it exists */
  get response(): Message | undefined {
    return this.responses.last();
  }

  /** Sends a message from the client. */
  send(request: Message) {
    this.sendClientStream(request);
  }

  /** Ends the client stream and waits for the RPC to complete. */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number,
  ): Promise<[Status, Message]> {
    this.finishClientStream(requests);
    return await this.unaryWait(timeoutMs);
  }
}

/** Tracks the state of a server streaming RPC call. */
export class ServerStreamingCall extends Call {
  /** Awaits the end of the stream and returns all retained responses. */
  complete(timeoutMs?: number): Promise<[Status, Message[]]> {
    return this.streamWait(timeoutMs);
  }
}

/** Tracks the state of a bidirectional streaming RPC call. */
export class BidirectionalStreamingCall extends Call {
  /** Sends a message from the client. */
  send(request: Message) {
    this.sendClientStream(request);
  }

  /** Ends the client stream and waits for the RPC to complete. */
  async finishAndWait(
    requests: Array<Message> = [],
    timeoutMs?: number,
  ): Promise<[Status, Array<Message>]> {
    this.finishClientStream(requests);
    return await this.streamWait(timeoutMs);
  }
}