xref: /aosp_15_r20/external/pigweed/pw_rpc/ts/client.ts (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1// Copyright 2022 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15/** Provides a pw_rpc client for TypeScript. */
16
17import { ProtoCollection } from 'pigweedjs/pw_protobuf_compiler';
18import { Status } from 'pigweedjs/pw_status';
19import { Message } from 'google-protobuf';
20import {
21  PacketType,
22  RpcPacket,
23} from 'pigweedjs/protos/pw_rpc/internal/packet_pb';
24
25import { Channel, Service } from './descriptors';
26import { MethodStub, methodStubFactory } from './method';
27import * as packets from './packets';
28import { PendingCalls, Rpc } from './rpc_classes';
29
/**
 * Client-side handle for a single RPC service on a single channel.
 *
 * Holds one MethodStub per method of the wrapped service; stubs can be looked
 * up by the method's `name`.
 */
export class ServiceClient {
  private service: Service;
  private methods: MethodStub[] = [];
  methodsByName = new Map<string, MethodStub>();

  /**
   * @param client Client whose pending-call table (`client.rpcs`) every stub
   *     created here shares.
   * @param channel Channel the stubs are bound to.
   * @param service Descriptor whose methods are wrapped in stubs.
   */
  constructor(client: Client, channel: Channel, service: Service) {
    this.service = service;
    for (const method of service.methods.values()) {
      const methodStub = methodStubFactory(client.rpcs, channel, method);
      this.methods.push(methodStub);
      this.methodsByName.set(method.name, methodStub);
    }
  }

  /** Returns the stub registered under `methodName`, or undefined if none. */
  method(methodName: string): MethodStub | undefined {
    const found = this.methodsByName.get(methodName);
    return found;
  }

  /** Numeric id of the wrapped service. */
  get id(): number {
    const { id } = this.service;
    return id;
  }

  /** Name of the wrapped service. */
  get name(): string {
    const { name } = this.service;
    return name;
  }
}
60
/**
 * Per-channel view of a Client: exposes one ServiceClient for every service,
 * keyed by the service's full name.
 */
export class ChannelClient {
  readonly channel: Channel;
  services = new Map<string, ServiceClient>();

  /**
   * @param client Owning client, forwarded to each ServiceClient.
   * @param channel Channel this view is bound to.
   * @param services Services to expose on this channel.
   */
  constructor(client: Client, channel: Channel, services: Service[]) {
    this.channel = channel;
    for (const svc of services) {
      this.services.set(svc.name, new ServiceClient(client, this.channel, svc));
    }
  }

  /**
   * Find a service client via its full name.
   *
   * For example:
   * `service = client.channel().service('the.package.FooService');`
   */
  service(serviceName: string): ServiceClient | undefined {
    const serviceClient = this.services.get(serviceName);
    return serviceClient;
  }

  /**
   * Find a method stub via its full name.
   *
   * The name is split at its last '.': everything before it is the service's
   * full name, everything after it is the method name. An error is logged and
   * undefined returned when the name has no service part or no such method
   * exists.
   *
   * For example:
   * `method = client.channel().methodStub('the.package.AService.AMethod');`
   */
  methodStub(name: string): MethodStub | undefined {
    const separator = name.lastIndexOf('.');
    // -1 (no dot) and 0 (leading dot) both leave no service name.
    if (separator < 1) {
      console.error(`Malformed method name: ${name}`);
      return undefined;
    }
    const serviceName = name.substring(0, separator);
    const methodName = name.substring(separator + 1);
    const stub = this.service(serviceName)?.method(methodName);
    if (stub === undefined) {
      console.error(`Method not found: ${name}`);
    }
    return stub;
  }
}
108
/**
 * Top-level pw_rpc client. It owns the channels, the registered services and
 * the table of pending calls, and routes incoming packets to those calls.
 *
 * RPCs are invoked through a MethodStub, which can be found by its full name
 * via `ChannelClient.methodStub()`:
 *
 * ```
 * method = client.channel(1).methodStub('the.package.FooService.SomeMethod')
 * call = method.invoke(request);
 * ```
 */
export class Client {
  private channelsById = new Map<number, ChannelClient>();
  readonly rpcs: PendingCalls;
  readonly services = new Map<number, Service>();

  /**
   * @param channels Channels this client communicates over; each one gets its
   *     own ChannelClient.
   * @param services Services callable through this client. They are indexed
   *     here by numeric id, and every service is exposed on every channel.
   */
  constructor(channels: Channel[], services: Service[]) {
    this.rpcs = new PendingCalls();
    services.forEach((service) => {
      this.services.set(service.id, service);
    });

    channels.forEach((channel) => {
      this.channelsById.set(
        channel.id,
        new ChannelClient(this, channel, services),
      );
    });
  }

  /**
   * Creates a client from a set of Channels and a library of protos.
   *
   * Every service declared in any file of `protoSet` is registered.
   *
   * @param channels List of possible channels to use.
   * @param protoSet ProtoCollection containing the protos that define the RPC
   *     services and their methods.
   */
  static fromProtoSet(channels: Channel[], protoSet: ProtoCollection): Client {
    // Accumulate into one array; concat-per-service would copy the whole
    // array for every service found.
    const services: Service[] = [];
    for (const fileDescriptor of protoSet.fileDescriptorSet.getFileList()) {
      const packageName = fileDescriptor.getPackage()!;
      for (const serviceDescriptor of fileDescriptor.getServiceList()) {
        services.push(
          Service.fromProtoDescriptor(serviceDescriptor, protoSet, packageName),
        );
      }
    }
    return new Client(channels, services);
  }

  /**
   * Finds the channel with the provided id.
   *
   * @param id Channel id. If omitted, the first channel (in insertion order)
   *     is returned.
   * @returns The matching ChannelClient, or undefined if there are no channels
   *     or no channel with a matching id.
   */
  channel(id?: number): ChannelClient | undefined {
    if (id === undefined) {
      return this.channelsById.values().next().value;
    }
    return this.channelsById.get(id);
  }

  /**
   * Creates a new RPC object holding channel, method, and service info for
   * the packet's service and method ids.
   *
   * @returns undefined if the service or method is not registered.
   */
  private rpc(
    packet: RpcPacket,
    channelClient: ChannelClient,
  ): Rpc | undefined {
    const service = this.services.get(packet.getServiceId());
    if (service === undefined) {
      return undefined;
    }
    const method = service.methods.get(packet.getMethodId());
    if (method === undefined) {
      return undefined;
    }
    return new Rpc(channelClient.channel, service, method);
  }

  /**
   * Extracts the status carried by a response packet.
   *
   * @returns undefined for SERVER_STREAM packets, which carry no status.
   *     A status code that is not a member of `Status` (e.g. corrupted or
   *     from a newer peer) is logged and reported as `Status.UNKNOWN`, so
   *     downstream handlers never see an out-of-range status.
   */
  private decodeStatus(rpc: Rpc, packet: RpcPacket): Status | undefined {
    if (packet.getType() === PacketType.SERVER_STREAM) {
      return undefined;
    }
    const code = packet.getStatus();
    // Numeric enums have a reverse mapping; unknown codes map to undefined.
    if (Status[code] === undefined) {
      console.warn(`Illegal status code ${code} for ${rpc}`);
      return Status.UNKNOWN;
    }
    return code;
  }

  /**
   * Decodes the response payload of a packet using the method's response
   * type (or its custom response serializer).
   *
   * @returns undefined for SERVER_ERROR packets and for RESPONSE packets of
   *     server-streaming methods, whose final RESPONSE carries no payload.
   * @throws Whatever `packets.decodePayload` throws when the payload cannot
   *     be decoded.
   */
  private decodePayload(rpc: Rpc, packet: RpcPacket): Message | undefined {
    if (packet.getType() === PacketType.SERVER_ERROR) {
      return undefined;
    }

    if (
      packet.getType() === PacketType.RESPONSE &&
      rpc.method.serverStreaming
    ) {
      return undefined;
    }

    const payload = packet.getPayload_asU8();
    return packets.decodePayload(
      payload,
      rpc.method.responseType,
      rpc.method.customResponseSerializer,
    );
  }

  /** Sends a CLIENT_ERROR with `error` back over the packet's channel. */
  private sendClientError(
    client: ChannelClient,
    packet: RpcPacket,
    error: Status,
  ) {
    client.channel.send(packets.encodeClientError(packet, error));
  }

  /**
   * Processes an incoming packet.
   *
   * @param rawPacketData binary data for a pw_rpc packet.
   * @returns The status of processing the packet.
   *    - OK: the packet was processed by the client
   *    - DATA_LOSS: the packet could not be decoded
   *    - INVALID_ARGUMENT: the packet is for a server, not a client
   *    - NOT_FOUND: the packet's channel ID is not known to this client
   * @throws Error if a SERVER_ERROR packet carries an OK (or no) status.
   */
  processPacket(rawPacketData: Uint8Array): Status {
    let packet: RpcPacket;
    try {
      packet = packets.decode(rawPacketData);
    } catch (err) {
      console.warn(`Failed to decode packet: ${err}`);
      console.debug(`Raw packet: ${rawPacketData}`);
      return Status.DATA_LOSS;
    }

    if (packets.forServer(packet)) {
      return Status.INVALID_ARGUMENT;
    }

    const channelClient = this.channelsById.get(packet.getChannelId());
    if (channelClient === undefined) {
      console.warn(`Unrecognized channel ID: ${packet.getChannelId()}`);
      return Status.NOT_FOUND;
    }

    const rpc = this.rpc(packet, channelClient);
    if (rpc === undefined) {
      // Tell the server it addressed a service/method this client lacks.
      this.sendClientError(channelClient, packet, Status.NOT_FOUND);
      console.warn('rpc service/method not found');
      return Status.OK;
    }

    if (
      packet.getType() !== PacketType.RESPONSE &&
      packet.getType() !== PacketType.SERVER_STREAM &&
      packet.getType() !== PacketType.SERVER_ERROR
    ) {
      console.error(`${rpc}: Unexpected packet type ${packet.getType()}`);
      console.debug(`Packet: ${packet}`);
      return Status.OK;
    }

    let status = this.decodeStatus(rpc, packet);
    let payload: Message | undefined;
    try {
      payload = this.decodePayload(rpc, packet);
    } catch (error) {
      this.sendClientError(channelClient, packet, Status.DATA_LOSS);
      console.warn(`Failed to decode response: ${error}`);
      console.debug(`Raw payload: ${packet.getPayload()}`);

      // Make this an error packet so the error handler is called.
      packet.setType(PacketType.SERVER_ERROR);
      status = Status.DATA_LOSS;
    }

    const call = this.rpcs.getPending(rpc, packet.getCallId(), status);
    if (call === undefined) {
      this.sendClientError(channelClient, packet, Status.FAILED_PRECONDITION);
      console.debug(`Discarding response for ${rpc}, which is not pending`);
      return Status.OK;
    }

    if (packet.getType() === PacketType.SERVER_ERROR) {
      if (status === Status.OK) {
        throw new Error('Unexpected OK status on SERVER_ERROR');
      }
      if (status === undefined) {
        throw new Error('Missing status on SERVER_ERROR');
      }
      console.warn(`${rpc}: invocation failed with status: ${Status[status]}`);
      call.handleError(status);
      return Status.OK;
    }

    // A packet may carry a payload, a final status, or both (unary RESPONSE).
    if (payload !== undefined) {
      call.handleResponse(payload);
    }
    if (status !== undefined) {
      call.handleCompletion(status);
    }
    return Status.OK;
  }
}
316