1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package json
6
7import (
8	"bytes"
9	"fmt"
10	"io"
11	"log"
12	"net"
13	"net/http"
14	"net/http/httptest"
15	"path"
16	"reflect"
17	"runtime"
18	"runtime/debug"
19	"strings"
20	"testing"
21)
22
23// TODO(https://go.dev/issue/52751): Replace with native testing support.
24
// CaseName pairs a test-case name with the source position of the call that
// created it, so failure messages can point back at the table entry.
type CaseName struct {
	Name  string
	Where CasePos
}

// Name returns a CaseName for s whose Where field records the file and line
// of the code that called Name.
func Name(s string) (c CaseName) {
	c = CaseName{Name: s}
	// Skip 2 frames (runtime.Callers itself and Name) so the single recorded
	// PC belongs to Name's caller.
	runtime.Callers(2, c.Where.pc[:])
	return c
}

// CasePos is a source position stored as a single program counter; it is
// resolved to a file and line only when String is called.
type CasePos struct{ pc [1]uintptr }

// String formats the position as "file:line", using only the base name of
// the file.
func (p CasePos) String() string {
	frame, _ := runtime.CallersFrames(p.pc[:]).Next()
	base := path.Base(frame.File)
	return fmt.Sprintf("%s:%d", base, frame.Line)
}
46
// Test values for the stream tests: one of each JSON kind.
// TestEncoder and TestDecoder compare prefixes of this slice with the
// matching lines of streamEncoded, and TestEncoderIndent encodes all of it
// and compares with streamEncodedIndent, so the order here must stay in sync
// with those strings.
var streamTest = []any{
	0.1,
	"hello",
	nil,
	true,
	false,
	[]any{"a", "b", "c"},
	// NOTE(review): the "K" key appears to be U+212A KELVIN SIGN rather than
	// ASCII 'K' — streamEncoded lists it after "ß", which only fits a key
	// whose bytes sort above "ß". Confirm before retyping this line.
	map[string]any{"K": "Kelvin", "ß": "long s"},
	3.14, // another value to make sure something can follow map
}
59
// streamEncoded is the expected compact Encoder output for streamTest: one
// JSON value per line, each terminated by '\n'. TestEncoder compares prefixes
// of it (taken with nlines), and TestDecoder decodes those prefixes after
// stripping every newline.
var streamEncoded = `0.1
"hello"
null
true
false
["a","b","c"]
{"ß":"long s","K":"Kelvin"}
3.14
`
69
70func TestEncoder(t *testing.T) {
71	for i := 0; i <= len(streamTest); i++ {
72		var buf strings.Builder
73		enc := NewEncoder(&buf)
74		// Check that enc.SetIndent("", "") turns off indentation.
75		enc.SetIndent(">", ".")
76		enc.SetIndent("", "")
77		for j, v := range streamTest[0:i] {
78			if err := enc.Encode(v); err != nil {
79				t.Fatalf("#%d.%d Encode error: %v", i, j, err)
80			}
81		}
82		if have, want := buf.String(), nlines(streamEncoded, i); have != want {
83			t.Errorf("encoding %d items: mismatch:", i)
84			diff(t, []byte(have), []byte(want))
85			break
86		}
87	}
88}
89
90func TestEncoderErrorAndReuseEncodeState(t *testing.T) {
91	// Disable the GC temporarily to prevent encodeState's in Pool being cleaned away during the test.
92	percent := debug.SetGCPercent(-1)
93	defer debug.SetGCPercent(percent)
94
95	// Trigger an error in Marshal with cyclic data.
96	type Dummy struct {
97		Name string
98		Next *Dummy
99	}
100	dummy := Dummy{Name: "Dummy"}
101	dummy.Next = &dummy
102
103	var buf bytes.Buffer
104	enc := NewEncoder(&buf)
105	if err := enc.Encode(dummy); err == nil {
106		t.Errorf("Encode(dummy) error: got nil, want non-nil")
107	}
108
109	type Data struct {
110		A string
111		I int
112	}
113	want := Data{A: "a", I: 1}
114	if err := enc.Encode(want); err != nil {
115		t.Errorf("Marshal error: %v", err)
116	}
117
118	var got Data
119	if err := Unmarshal(buf.Bytes(), &got); err != nil {
120		t.Errorf("Unmarshal error: %v", err)
121	}
122	if got != want {
123		t.Errorf("Marshal/Unmarshal roundtrip:\n\tgot:  %v\n\twant: %v", got, want)
124	}
125}
126
// streamEncodedIndent is the expected Encoder output for streamTest after
// SetIndent(">", "."). Scalars stay on one line. Arrays and objects span
// several lines: every line after the opening bracket starts with the ">"
// prefix, and element lines additionally carry the "." indent.
var streamEncodedIndent = `0.1
"hello"
null
true
false
[
>."a",
>."b",
>."c"
>]
{
>."ß": "long s",
>."K": "Kelvin"
>}
3.14
`
143
144func TestEncoderIndent(t *testing.T) {
145	var buf strings.Builder
146	enc := NewEncoder(&buf)
147	enc.SetIndent(">", ".")
148	for _, v := range streamTest {
149		enc.Encode(v)
150	}
151	if have, want := buf.String(), streamEncodedIndent; have != want {
152		t.Error("Encode mismatch:")
153		diff(t, []byte(have), []byte(want))
154	}
155}
156
// strMarshaler is a string type whose MarshalJSON (value receiver) returns
// its contents verbatim; no quoting or escaping is applied here.
type strMarshaler string

// MarshalJSON returns the bytes of s unchanged and never fails.
func (s strMarshaler) MarshalJSON() ([]byte, error) {
	out := []byte(s)
	return out, nil
}
162
// strPtrMarshaler is a string type whose MarshalJSON is declared on the
// pointer receiver, so only *strPtrMarshaler has it in its method set. It
// returns the string's contents verbatim.
type strPtrMarshaler string

// MarshalJSON returns the bytes of the pointed-to string unchanged and never
// fails.
func (s *strPtrMarshaler) MarshalJSON() ([]byte, error) {
	str := string(*s)
	return []byte(str), nil
}
168
169func TestEncoderSetEscapeHTML(t *testing.T) {
170	var c C
171	var ct CText
172	var tagStruct struct {
173		Valid   int `json:"<>&#! "`
174		Invalid int `json:"\\"`
175	}
176
177	// This case is particularly interesting, as we force the encoder to
178	// take the address of the Ptr field to use its MarshalJSON method. This
179	// is why the '&' is important.
180	marshalerStruct := &struct {
181		NonPtr strMarshaler
182		Ptr    strPtrMarshaler
183	}{`"<str>"`, `"<str>"`}
184
185	// https://golang.org/issue/34154
186	stringOption := struct {
187		Bar string `json:"bar,string"`
188	}{`<html>foobar</html>`}
189
190	tests := []struct {
191		CaseName
192		v          any
193		wantEscape string
194		want       string
195	}{
196		{Name("c"), c, `"\u003c\u0026\u003e"`, `"<&>"`},
197		{Name("ct"), ct, `"\"\u003c\u0026\u003e\""`, `"\"<&>\""`},
198		{Name(`"<&>"`), "<&>", `"\u003c\u0026\u003e"`, `"<&>"`},
199		{
200			Name("tagStruct"), tagStruct,
201			`{"\u003c\u003e\u0026#! ":0,"Invalid":0}`,
202			`{"<>&#! ":0,"Invalid":0}`,
203		},
204		{
205			Name(`"<str>"`), marshalerStruct,
206			`{"NonPtr":"\u003cstr\u003e","Ptr":"\u003cstr\u003e"}`,
207			`{"NonPtr":"<str>","Ptr":"<str>"}`,
208		},
209		{
210			Name("stringOption"), stringOption,
211			`{"bar":"\"\\u003chtml\\u003efoobar\\u003c/html\\u003e\""}`,
212			`{"bar":"\"<html>foobar</html>\""}`,
213		},
214	}
215	for _, tt := range tests {
216		t.Run(tt.Name, func(t *testing.T) {
217			var buf strings.Builder
218			enc := NewEncoder(&buf)
219			if err := enc.Encode(tt.v); err != nil {
220				t.Fatalf("%s: Encode(%s) error: %s", tt.Where, tt.Name, err)
221			}
222			if got := strings.TrimSpace(buf.String()); got != tt.wantEscape {
223				t.Errorf("%s: Encode(%s):\n\tgot:  %s\n\twant: %s", tt.Where, tt.Name, got, tt.wantEscape)
224			}
225			buf.Reset()
226			enc.SetEscapeHTML(false)
227			if err := enc.Encode(tt.v); err != nil {
228				t.Fatalf("%s: SetEscapeHTML(false) Encode(%s) error: %s", tt.Where, tt.Name, err)
229			}
230			if got := strings.TrimSpace(buf.String()); got != tt.want {
231				t.Errorf("%s: SetEscapeHTML(false) Encode(%s):\n\tgot:  %s\n\twant: %s",
232					tt.Where, tt.Name, got, tt.want)
233			}
234		})
235	}
236}
237
238func TestDecoder(t *testing.T) {
239	for i := 0; i <= len(streamTest); i++ {
240		// Use stream without newlines as input,
241		// just to stress the decoder even more.
242		// Our test input does not include back-to-back numbers.
243		// Otherwise stripping the newlines would
244		// merge two adjacent JSON values.
245		var buf bytes.Buffer
246		for _, c := range nlines(streamEncoded, i) {
247			if c != '\n' {
248				buf.WriteRune(c)
249			}
250		}
251		out := make([]any, i)
252		dec := NewDecoder(&buf)
253		for j := range out {
254			if err := dec.Decode(&out[j]); err != nil {
255				t.Fatalf("decode #%d/%d error: %v", j, i, err)
256			}
257		}
258		if !reflect.DeepEqual(out, streamTest[0:i]) {
259			t.Errorf("decoding %d items: mismatch:", i)
260			for j := range out {
261				if !reflect.DeepEqual(out[j], streamTest[j]) {
262					t.Errorf("#%d:\n\tgot:  %v\n\twant: %v", j, out[j], streamTest[j])
263				}
264			}
265			break
266		}
267	}
268}
269
270func TestDecoderBuffered(t *testing.T) {
271	r := strings.NewReader(`{"Name": "Gopher"} extra `)
272	var m struct {
273		Name string
274	}
275	d := NewDecoder(r)
276	err := d.Decode(&m)
277	if err != nil {
278		t.Fatal(err)
279	}
280	if m.Name != "Gopher" {
281		t.Errorf("Name = %s, want Gopher", m.Name)
282	}
283	rest, err := io.ReadAll(d.Buffered())
284	if err != nil {
285		t.Fatal(err)
286	}
287	if got, want := string(rest), " extra "; got != want {
288		t.Errorf("Remaining = %s, want %s", got, want)
289	}
290}
291
// nlines returns the prefix of s made of its first n lines, each including
// its trailing '\n'. It returns "" when n <= 0, and all of s when s contains
// fewer than n newline characters.
func nlines(s string, n int) string {
	if n <= 0 {
		return ""
	}
	end := 0
	for ; n > 0; n-- {
		i := strings.IndexByte(s[end:], '\n')
		if i < 0 {
			return s
		}
		end += i + 1
	}
	return s[:end]
}
305
306func TestRawMessage(t *testing.T) {
307	var data struct {
308		X  float64
309		Id RawMessage
310		Y  float32
311	}
312	const raw = `["\u0056",null]`
313	const want = `{"X":0.1,"Id":["\u0056",null],"Y":0.2}`
314	err := Unmarshal([]byte(want), &data)
315	if err != nil {
316		t.Fatalf("Unmarshal error: %v", err)
317	}
318	if string([]byte(data.Id)) != raw {
319		t.Fatalf("Unmarshal:\n\tgot:  %s\n\twant: %s", []byte(data.Id), raw)
320	}
321	got, err := Marshal(&data)
322	if err != nil {
323		t.Fatalf("Marshal error: %v", err)
324	}
325	if string(got) != want {
326		t.Fatalf("Marshal:\n\tgot:  %s\n\twant: %s", got, want)
327	}
328}
329
330func TestNullRawMessage(t *testing.T) {
331	var data struct {
332		X     float64
333		Id    RawMessage
334		IdPtr *RawMessage
335		Y     float32
336	}
337	const want = `{"X":0.1,"Id":null,"IdPtr":null,"Y":0.2}`
338	err := Unmarshal([]byte(want), &data)
339	if err != nil {
340		t.Fatalf("Unmarshal error: %v", err)
341	}
342	if want, got := "null", string(data.Id); want != got {
343		t.Fatalf("Unmarshal:\n\tgot:  %s\n\twant: %s", got, want)
344	}
345	if data.IdPtr != nil {
346		t.Fatalf("pointer mismatch: got non-nil, want nil")
347	}
348	got, err := Marshal(&data)
349	if err != nil {
350		t.Fatalf("Marshal error: %v", err)
351	}
352	if string(got) != want {
353		t.Fatalf("Marshal:\n\tgot:  %s\n\twant: %s", got, want)
354	}
355}
356
357func TestBlocking(t *testing.T) {
358	tests := []struct {
359		CaseName
360		in string
361	}{
362		{Name(""), `{"x": 1}`},
363		{Name(""), `[1, 2, 3]`},
364	}
365	for _, tt := range tests {
366		t.Run(tt.Name, func(t *testing.T) {
367			r, w := net.Pipe()
368			go w.Write([]byte(tt.in))
369			var val any
370
371			// If Decode reads beyond what w.Write writes above,
372			// it will block, and the test will deadlock.
373			if err := NewDecoder(r).Decode(&val); err != nil {
374				t.Errorf("%s: NewDecoder(%s).Decode error: %v", tt.Where, tt.in, err)
375			}
376			r.Close()
377			w.Close()
378		})
379	}
380}
381
// decodeThis marks an entry in TestDecodeInStream's expected-token list that
// must be read with Decoder.Decode instead of Decoder.Token. v is the value
// (or the error) that the Decode call is expected to produce.
type decodeThis struct {
	v any
}
385
386func TestDecodeInStream(t *testing.T) {
387	tests := []struct {
388		CaseName
389		json      string
390		expTokens []any
391	}{
392		// streaming token cases
393		{CaseName: Name(""), json: `10`, expTokens: []any{float64(10)}},
394		{CaseName: Name(""), json: ` [10] `, expTokens: []any{
395			Delim('['), float64(10), Delim(']')}},
396		{CaseName: Name(""), json: ` [false,10,"b"] `, expTokens: []any{
397			Delim('['), false, float64(10), "b", Delim(']')}},
398		{CaseName: Name(""), json: `{ "a": 1 }`, expTokens: []any{
399			Delim('{'), "a", float64(1), Delim('}')}},
400		{CaseName: Name(""), json: `{"a": 1, "b":"3"}`, expTokens: []any{
401			Delim('{'), "a", float64(1), "b", "3", Delim('}')}},
402		{CaseName: Name(""), json: ` [{"a": 1},{"a": 2}] `, expTokens: []any{
403			Delim('['),
404			Delim('{'), "a", float64(1), Delim('}'),
405			Delim('{'), "a", float64(2), Delim('}'),
406			Delim(']')}},
407		{CaseName: Name(""), json: `{"obj": {"a": 1}}`, expTokens: []any{
408			Delim('{'), "obj", Delim('{'), "a", float64(1), Delim('}'),
409			Delim('}')}},
410		{CaseName: Name(""), json: `{"obj": [{"a": 1}]}`, expTokens: []any{
411			Delim('{'), "obj", Delim('['),
412			Delim('{'), "a", float64(1), Delim('}'),
413			Delim(']'), Delim('}')}},
414
415		// streaming tokens with intermittent Decode()
416		{CaseName: Name(""), json: `{ "a": 1 }`, expTokens: []any{
417			Delim('{'), "a",
418			decodeThis{float64(1)},
419			Delim('}')}},
420		{CaseName: Name(""), json: ` [ { "a" : 1 } ] `, expTokens: []any{
421			Delim('['),
422			decodeThis{map[string]any{"a": float64(1)}},
423			Delim(']')}},
424		{CaseName: Name(""), json: ` [{"a": 1},{"a": 2}] `, expTokens: []any{
425			Delim('['),
426			decodeThis{map[string]any{"a": float64(1)}},
427			decodeThis{map[string]any{"a": float64(2)}},
428			Delim(']')}},
429		{CaseName: Name(""), json: `{ "obj" : [ { "a" : 1 } ] }`, expTokens: []any{
430			Delim('{'), "obj", Delim('['),
431			decodeThis{map[string]any{"a": float64(1)}},
432			Delim(']'), Delim('}')}},
433
434		{CaseName: Name(""), json: `{"obj": {"a": 1}}`, expTokens: []any{
435			Delim('{'), "obj",
436			decodeThis{map[string]any{"a": float64(1)}},
437			Delim('}')}},
438		{CaseName: Name(""), json: `{"obj": [{"a": 1}]}`, expTokens: []any{
439			Delim('{'), "obj",
440			decodeThis{[]any{
441				map[string]any{"a": float64(1)},
442			}},
443			Delim('}')}},
444		{CaseName: Name(""), json: ` [{"a": 1} {"a": 2}] `, expTokens: []any{
445			Delim('['),
446			decodeThis{map[string]any{"a": float64(1)}},
447			decodeThis{&SyntaxError{"expected comma after array element", 11}},
448		}},
449		{CaseName: Name(""), json: `{ "` + strings.Repeat("a", 513) + `" 1 }`, expTokens: []any{
450			Delim('{'), strings.Repeat("a", 513),
451			decodeThis{&SyntaxError{"expected colon after object key", 518}},
452		}},
453		{CaseName: Name(""), json: `{ "\a" }`, expTokens: []any{
454			Delim('{'),
455			&SyntaxError{"invalid character 'a' in string escape code", 3},
456		}},
457		{CaseName: Name(""), json: ` \a`, expTokens: []any{
458			&SyntaxError{"invalid character '\\\\' looking for beginning of value", 1},
459		}},
460	}
461	for _, tt := range tests {
462		t.Run(tt.Name, func(t *testing.T) {
463			dec := NewDecoder(strings.NewReader(tt.json))
464			for i, want := range tt.expTokens {
465				var got any
466				var err error
467
468				if dt, ok := want.(decodeThis); ok {
469					want = dt.v
470					err = dec.Decode(&got)
471				} else {
472					got, err = dec.Token()
473				}
474				if errWant, ok := want.(error); ok {
475					if err == nil || !reflect.DeepEqual(err, errWant) {
476						t.Fatalf("%s:\n\tinput: %s\n\tgot error:  %v\n\twant error: %v", tt.Where, tt.json, err, errWant)
477					}
478					break
479				} else if err != nil {
480					t.Fatalf("%s:\n\tinput: %s\n\tgot error:  %v\n\twant error: nil", tt.Where, tt.json, err)
481				}
482				if !reflect.DeepEqual(got, want) {
483					t.Fatalf("%s: token %d:\n\tinput: %s\n\tgot:  %T(%v)\n\twant: %T(%v)", tt.Where, i, tt.json, got, got, want, want)
484				}
485			}
486		})
487	}
488}
489
490// Test from golang.org/issue/11893
491func TestHTTPDecoding(t *testing.T) {
492	const raw = `{ "foo": "bar" }`
493
494	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
495		w.Write([]byte(raw))
496	}))
497	defer ts.Close()
498	res, err := http.Get(ts.URL)
499	if err != nil {
500		log.Fatalf("http.Get error: %v", err)
501	}
502	defer res.Body.Close()
503
504	foo := struct {
505		Foo string
506	}{}
507
508	d := NewDecoder(res.Body)
509	err = d.Decode(&foo)
510	if err != nil {
511		t.Fatalf("Decode error: %v", err)
512	}
513	if foo.Foo != "bar" {
514		t.Errorf(`Decode: got %q, want "bar"`, foo.Foo)
515	}
516
517	// make sure we get the EOF the second time
518	err = d.Decode(&foo)
519	if err != io.EOF {
520		t.Errorf("Decode error:\n\tgot:  %v\n\twant: io.EOF", err)
521	}
522}
523