1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jsonrpc
6
7import (
8	"encoding/json"
9	"errors"
10	"fmt"
11	"io"
12	"net"
13	"net/rpc"
14	"reflect"
15	"strings"
16	"testing"
17)
18
// Args holds the two integer operands passed to every Arith method.
type Args struct {
	A int
	B int
}

// Reply carries the single integer result written by an Arith method.
type Reply struct {
	C int
}

// Arith is the arithmetic RPC service registered for these tests; the
// underlying int value is never read.
type Arith int

// ArithAddResp mirrors the JSON-RPC response envelope (id, result, error) so
// tests can decode raw server output for Arith calls directly.
type ArithAddResp struct {
	Id     any   `json:"id"`
	Result Reply `json:"result"`
	Error  any   `json:"error"`
}

// Add stores args.A + args.B in reply.C. It never fails.
func (*Arith) Add(args *Args, reply *Reply) error {
	sum := args.A + args.B
	reply.C = sum
	return nil
}

// Mul stores args.A * args.B in reply.C. It never fails.
func (*Arith) Mul(args *Args, reply *Reply) error {
	product := args.A * args.B
	reply.C = product
	return nil
}

// Div stores args.A / args.B (Go integer division) in reply.C. A zero divisor
// is reported as a "divide by zero" error and reply is left untouched.
func (*Arith) Div(args *Args, reply *Reply) error {
	if args.B != 0 {
		reply.C = args.A / args.B
		return nil
	}
	return errors.New("divide by zero")
}

// Error unconditionally panics with the string "ERROR". NOTE(review):
// presumably kept to exercise a panicking service method; no test in this
// view calls it.
func (*Arith) Error(args *Args, reply *Reply) error {
	panic("ERROR")
}
56
// BuiltinTypes is an RPC service whose methods reply with non-struct types
// (a map, a slice and a fixed-size array).
type BuiltinTypes struct{}

// Map sets entry i to i in the caller-supplied map. The map must already be
// allocated; writing to a nil map panics.
func (BuiltinTypes) Map(i int, reply *map[int]int) error {
	m := *reply
	m[i] = i
	return nil
}

// Slice appends i to the caller-supplied slice.
func (BuiltinTypes) Slice(i int, reply *[]int) error {
	s := *reply
	s = append(s, i)
	*reply = s
	return nil
}

// Array stores i in the single element of the caller-supplied array.
func (BuiltinTypes) Array(i int, reply *[1]int) error {
	reply[0] = i
	return nil
}
73
74func init() {
75	rpc.Register(new(Arith))
76	rpc.Register(BuiltinTypes{})
77}
78
79func TestServerNoParams(t *testing.T) {
80	cli, srv := net.Pipe()
81	defer cli.Close()
82	go ServeConn(srv)
83	dec := json.NewDecoder(cli)
84
85	fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`)
86	var resp ArithAddResp
87	if err := dec.Decode(&resp); err != nil {
88		t.Fatalf("Decode after no params: %s", err)
89	}
90	if resp.Error == nil {
91		t.Fatalf("Expected error, got nil")
92	}
93}
94
95func TestServerEmptyMessage(t *testing.T) {
96	cli, srv := net.Pipe()
97	defer cli.Close()
98	go ServeConn(srv)
99	dec := json.NewDecoder(cli)
100
101	fmt.Fprintf(cli, "{}")
102	var resp ArithAddResp
103	if err := dec.Decode(&resp); err != nil {
104		t.Fatalf("Decode after empty: %s", err)
105	}
106	if resp.Error == nil {
107		t.Fatalf("Expected error, got nil")
108	}
109}
110
111func TestServer(t *testing.T) {
112	cli, srv := net.Pipe()
113	defer cli.Close()
114	go ServeConn(srv)
115	dec := json.NewDecoder(cli)
116
117	// Send hand-coded requests to server, parse responses.
118	for i := 0; i < 10; i++ {
119		fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
120		var resp ArithAddResp
121		err := dec.Decode(&resp)
122		if err != nil {
123			t.Fatalf("Decode: %s", err)
124		}
125		if resp.Error != nil {
126			t.Fatalf("resp.Error: %s", resp.Error)
127		}
128		if resp.Id.(string) != string(rune(i)) {
129			t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(rune(i)))
130		}
131		if resp.Result.C != 2*i+1 {
132			t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
133		}
134	}
135}
136
137func TestClient(t *testing.T) {
138	// Assume server is okay (TestServer is above).
139	// Test client against server.
140	cli, srv := net.Pipe()
141	go ServeConn(srv)
142
143	client := NewClient(cli)
144	defer client.Close()
145
146	// Synchronous calls
147	args := &Args{7, 8}
148	reply := new(Reply)
149	err := client.Call("Arith.Add", args, reply)
150	if err != nil {
151		t.Errorf("Add: expected no error but got string %q", err.Error())
152	}
153	if reply.C != args.A+args.B {
154		t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
155	}
156
157	args = &Args{7, 8}
158	reply = new(Reply)
159	err = client.Call("Arith.Mul", args, reply)
160	if err != nil {
161		t.Errorf("Mul: expected no error but got string %q", err.Error())
162	}
163	if reply.C != args.A*args.B {
164		t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
165	}
166
167	// Out of order.
168	args = &Args{7, 8}
169	mulReply := new(Reply)
170	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
171	addReply := new(Reply)
172	addCall := client.Go("Arith.Add", args, addReply, nil)
173
174	addCall = <-addCall.Done
175	if addCall.Error != nil {
176		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
177	}
178	if addReply.C != args.A+args.B {
179		t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
180	}
181
182	mulCall = <-mulCall.Done
183	if mulCall.Error != nil {
184		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
185	}
186	if mulReply.C != args.A*args.B {
187		t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
188	}
189
190	// Error test
191	args = &Args{7, 0}
192	reply = new(Reply)
193	err = client.Call("Arith.Div", args, reply)
194	// expect an error: zero divide
195	if err == nil {
196		t.Error("Div: expected error")
197	} else if err.Error() != "divide by zero" {
198		t.Error("Div: expected divide by zero error; got", err)
199	}
200}
201
202func TestBuiltinTypes(t *testing.T) {
203	cli, srv := net.Pipe()
204	go ServeConn(srv)
205
206	client := NewClient(cli)
207	defer client.Close()
208
209	// Map
210	arg := 7
211	replyMap := map[int]int{}
212	err := client.Call("BuiltinTypes.Map", arg, &replyMap)
213	if err != nil {
214		t.Errorf("Map: expected no error but got string %q", err.Error())
215	}
216	if replyMap[arg] != arg {
217		t.Errorf("Map: expected %d got %d", arg, replyMap[arg])
218	}
219
220	// Slice
221	replySlice := []int{}
222	err = client.Call("BuiltinTypes.Slice", arg, &replySlice)
223	if err != nil {
224		t.Errorf("Slice: expected no error but got string %q", err.Error())
225	}
226	if e := []int{arg}; !reflect.DeepEqual(replySlice, e) {
227		t.Errorf("Slice: expected %v got %v", e, replySlice)
228	}
229
230	// Array
231	replyArray := [1]int{}
232	err = client.Call("BuiltinTypes.Array", arg, &replyArray)
233	if err != nil {
234		t.Errorf("Array: expected no error but got string %q", err.Error())
235	}
236	if e := [1]int{arg}; !reflect.DeepEqual(replyArray, e) {
237		t.Errorf("Array: expected %v got %v", e, replyArray)
238	}
239}
240
241func TestMalformedInput(t *testing.T) {
242	cli, srv := net.Pipe()
243	go cli.Write([]byte(`{id:1}`)) // invalid json
244	ServeConn(srv)                 // must return, not loop
245}
246
247func TestMalformedOutput(t *testing.T) {
248	cli, srv := net.Pipe()
249	go srv.Write([]byte(`{"id":0,"result":null,"error":null}`))
250	go io.ReadAll(srv)
251
252	client := NewClient(cli)
253	defer client.Close()
254
255	args := &Args{7, 8}
256	reply := new(Reply)
257	err := client.Call("Arith.Add", args, reply)
258	if err == nil {
259		t.Error("expected error")
260	}
261}
262
263func TestServerErrorHasNullResult(t *testing.T) {
264	var out strings.Builder
265	sc := NewServerCodec(struct {
266		io.Reader
267		io.Writer
268		io.Closer
269	}{
270		Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`),
271		Writer: &out,
272		Closer: io.NopCloser(nil),
273	})
274	r := new(rpc.Request)
275	if err := sc.ReadRequestHeader(r); err != nil {
276		t.Fatal(err)
277	}
278	const valueText = "the value we don't want to see"
279	const errorText = "some error"
280	err := sc.WriteResponse(&rpc.Response{
281		ServiceMethod: "Method",
282		Seq:           1,
283		Error:         errorText,
284	}, valueText)
285	if err != nil {
286		t.Fatal(err)
287	}
288	if !strings.Contains(out.String(), errorText) {
289		t.Fatalf("Response didn't contain expected error %q: %s", errorText, &out)
290	}
291	if strings.Contains(out.String(), valueText) {
292		t.Errorf("Response contains both an error and value: %s", &out)
293	}
294}
295
296func TestUnexpectedError(t *testing.T) {
297	cli, srv := myPipe()
298	go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error
299	ServeConn(srv)                                                    // must return, not loop
300}
301
// myPipe returns the two ends of a synchronous, in-memory, full-duplex
// connection built from a pair of io.Pipes; each end reads what the other
// writes. Copied from package net.
func myPipe() (*pipe, *pipe) {
	aRead, bWrite := io.Pipe()
	bRead, aWrite := io.Pipe()
	return &pipe{aRead, aWrite}, &pipe{bRead, bWrite}
}

// pipe is one end of the connection returned by myPipe: it reads from one
// io.Pipe and writes to the other.
type pipe struct {
	*io.PipeReader
	*io.PipeWriter
}

// pipeAddr is the placeholder address reported by both ends of a pipe.
type pipeAddr int

// Network reports the network name "pipe".
func (pipeAddr) Network() string { return "pipe" }

// String reports the address as "pipe".
func (pipeAddr) String() string { return "pipe" }

// Close closes both the read and write halves. It returns the reader's close
// error if there is one, otherwise the writer's.
func (p *pipe) Close() error {
	readErr := p.PipeReader.Close()
	writeErr := p.PipeWriter.Close()
	if readErr != nil {
		return readErr
	}
	return writeErr
}

// LocalAddr returns the placeholder pipe address.
func (p *pipe) LocalAddr() net.Addr { return pipeAddr(0) }

// RemoteAddr returns the placeholder pipe address.
func (p *pipe) RemoteAddr() net.Addr { return pipeAddr(0) }

// SetTimeout always fails: timeouts are not supported.
func (p *pipe) SetTimeout(nsec int64) error {
	return errors.New("net.Pipe does not support timeouts")
}

// SetReadTimeout always fails with the same error as SetTimeout.
func (p *pipe) SetReadTimeout(nsec int64) error {
	return p.SetTimeout(nsec)
}

// SetWriteTimeout always fails with the same error as SetTimeout.
func (p *pipe) SetWriteTimeout(nsec int64) error {
	return p.SetTimeout(nsec)
}
353