// xref: /aosp_15_r20/external/golang-protobuf/encoding/protojson/decode.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package protojson
6
7import (
8	"encoding/base64"
9	"fmt"
10	"math"
11	"strconv"
12	"strings"
13
14	"google.golang.org/protobuf/internal/encoding/json"
15	"google.golang.org/protobuf/internal/encoding/messageset"
16	"google.golang.org/protobuf/internal/errors"
17	"google.golang.org/protobuf/internal/flags"
18	"google.golang.org/protobuf/internal/genid"
19	"google.golang.org/protobuf/internal/pragma"
20	"google.golang.org/protobuf/internal/set"
21	"google.golang.org/protobuf/proto"
22	"google.golang.org/protobuf/reflect/protoreflect"
23	"google.golang.org/protobuf/reflect/protoregistry"
24)
25
26// Unmarshal reads the given []byte into the given proto.Message.
27// The provided message must be mutable (e.g., a non-nil pointer to a message).
28func Unmarshal(b []byte, m proto.Message) error {
29	return UnmarshalOptions{}.Unmarshal(b, m)
30}
31
32// UnmarshalOptions is a configurable JSON format parser.
33type UnmarshalOptions struct {
34	pragma.NoUnkeyedLiterals
35
36	// If AllowPartial is set, input for messages that will result in missing
37	// required fields will not return an error.
38	AllowPartial bool
39
40	// If DiscardUnknown is set, unknown fields are ignored.
41	DiscardUnknown bool
42
43	// Resolver is used for looking up types when unmarshaling
44	// google.protobuf.Any messages or extension fields.
45	// If nil, this defaults to using protoregistry.GlobalTypes.
46	Resolver interface {
47		protoregistry.MessageTypeResolver
48		protoregistry.ExtensionTypeResolver
49	}
50}
51
52// Unmarshal reads the given []byte and populates the given proto.Message
53// using options in the UnmarshalOptions object.
54// It will clear the message first before setting the fields.
55// If it returns an error, the given message may be partially set.
56// The provided message must be mutable (e.g., a non-nil pointer to a message).
57func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
58	return o.unmarshal(b, m)
59}
60
61// unmarshal is a centralized function that all unmarshal operations go through.
62// For profiling purposes, avoid changing the name of this function or
63// introducing other code paths for unmarshal that do not go through this.
64func (o UnmarshalOptions) unmarshal(b []byte, m proto.Message) error {
65	proto.Reset(m)
66
67	if o.Resolver == nil {
68		o.Resolver = protoregistry.GlobalTypes
69	}
70
71	dec := decoder{json.NewDecoder(b), o}
72	if err := dec.unmarshalMessage(m.ProtoReflect(), false); err != nil {
73		return err
74	}
75
76	// Check for EOF.
77	tok, err := dec.Read()
78	if err != nil {
79		return err
80	}
81	if tok.Kind() != json.EOF {
82		return dec.unexpectedTokenError(tok)
83	}
84
85	if o.AllowPartial {
86		return nil
87	}
88	return proto.CheckInitialized(m)
89}
90
91type decoder struct {
92	*json.Decoder
93	opts UnmarshalOptions
94}
95
96// newError returns an error object with position info.
97func (d decoder) newError(pos int, f string, x ...interface{}) error {
98	line, column := d.Position(pos)
99	head := fmt.Sprintf("(line %d:%d): ", line, column)
100	return errors.New(head+f, x...)
101}
102
103// unexpectedTokenError returns a syntax error for the given unexpected token.
104func (d decoder) unexpectedTokenError(tok json.Token) error {
105	return d.syntaxError(tok.Pos(), "unexpected token %s", tok.RawString())
106}
107
108// syntaxError returns a syntax error for given position.
109func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
110	line, column := d.Position(pos)
111	head := fmt.Sprintf("syntax error (line %d:%d): ", line, column)
112	return errors.New(head+f, x...)
113}
114
115// unmarshalMessage unmarshals a message into the given protoreflect.Message.
116func (d decoder) unmarshalMessage(m protoreflect.Message, skipTypeURL bool) error {
117	if unmarshal := wellKnownTypeUnmarshaler(m.Descriptor().FullName()); unmarshal != nil {
118		return unmarshal(d, m)
119	}
120
121	tok, err := d.Read()
122	if err != nil {
123		return err
124	}
125	if tok.Kind() != json.ObjectOpen {
126		return d.unexpectedTokenError(tok)
127	}
128
129	messageDesc := m.Descriptor()
130	if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
131		return errors.New("no support for proto1 MessageSets")
132	}
133
134	var seenNums set.Ints
135	var seenOneofs set.Ints
136	fieldDescs := messageDesc.Fields()
137	for {
138		// Read field name.
139		tok, err := d.Read()
140		if err != nil {
141			return err
142		}
143		switch tok.Kind() {
144		default:
145			return d.unexpectedTokenError(tok)
146		case json.ObjectClose:
147			return nil
148		case json.Name:
149			// Continue below.
150		}
151
152		name := tok.Name()
153		// Unmarshaling a non-custom embedded message in Any will contain the
154		// JSON field "@type" which should be skipped because it is not a field
155		// of the embedded message, but simply an artifact of the Any format.
156		if skipTypeURL && name == "@type" {
157			d.Read()
158			continue
159		}
160
161		// Get the FieldDescriptor.
162		var fd protoreflect.FieldDescriptor
163		if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
164			// Only extension names are in [name] format.
165			extName := protoreflect.FullName(name[1 : len(name)-1])
166			extType, err := d.opts.Resolver.FindExtensionByName(extName)
167			if err != nil && err != protoregistry.NotFound {
168				return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
169			}
170			if extType != nil {
171				fd = extType.TypeDescriptor()
172				if !messageDesc.ExtensionRanges().Has(fd.Number()) || fd.ContainingMessage().FullName() != messageDesc.FullName() {
173					return d.newError(tok.Pos(), "message %v cannot be extended by %v", messageDesc.FullName(), fd.FullName())
174				}
175			}
176		} else {
177			// The name can either be the JSON name or the proto field name.
178			fd = fieldDescs.ByJSONName(name)
179			if fd == nil {
180				fd = fieldDescs.ByTextName(name)
181			}
182		}
183		if flags.ProtoLegacy {
184			if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
185				fd = nil // reset since the weak reference is not linked in
186			}
187		}
188
189		if fd == nil {
190			// Field is unknown.
191			if d.opts.DiscardUnknown {
192				if err := d.skipJSONValue(); err != nil {
193					return err
194				}
195				continue
196			}
197			return d.newError(tok.Pos(), "unknown field %v", tok.RawString())
198		}
199
200		// Do not allow duplicate fields.
201		num := uint64(fd.Number())
202		if seenNums.Has(num) {
203			return d.newError(tok.Pos(), "duplicate field %v", tok.RawString())
204		}
205		seenNums.Set(num)
206
207		// No need to set values for JSON null unless the field type is
208		// google.protobuf.Value or google.protobuf.NullValue.
209		if tok, _ := d.Peek(); tok.Kind() == json.Null && !isKnownValue(fd) && !isNullValue(fd) {
210			d.Read()
211			continue
212		}
213
214		switch {
215		case fd.IsList():
216			list := m.Mutable(fd).List()
217			if err := d.unmarshalList(list, fd); err != nil {
218				return err
219			}
220		case fd.IsMap():
221			mmap := m.Mutable(fd).Map()
222			if err := d.unmarshalMap(mmap, fd); err != nil {
223				return err
224			}
225		default:
226			// If field is a oneof, check if it has already been set.
227			if od := fd.ContainingOneof(); od != nil {
228				idx := uint64(od.Index())
229				if seenOneofs.Has(idx) {
230					return d.newError(tok.Pos(), "error parsing %s, oneof %v is already set", tok.RawString(), od.FullName())
231				}
232				seenOneofs.Set(idx)
233			}
234
235			// Required or optional fields.
236			if err := d.unmarshalSingular(m, fd); err != nil {
237				return err
238			}
239		}
240	}
241}
242
243func isKnownValue(fd protoreflect.FieldDescriptor) bool {
244	md := fd.Message()
245	return md != nil && md.FullName() == genid.Value_message_fullname
246}
247
248func isNullValue(fd protoreflect.FieldDescriptor) bool {
249	ed := fd.Enum()
250	return ed != nil && ed.FullName() == genid.NullValue_enum_fullname
251}
252
253// unmarshalSingular unmarshals to the non-repeated field specified
254// by the given FieldDescriptor.
255func (d decoder) unmarshalSingular(m protoreflect.Message, fd protoreflect.FieldDescriptor) error {
256	var val protoreflect.Value
257	var err error
258	switch fd.Kind() {
259	case protoreflect.MessageKind, protoreflect.GroupKind:
260		val = m.NewField(fd)
261		err = d.unmarshalMessage(val.Message(), false)
262	default:
263		val, err = d.unmarshalScalar(fd)
264	}
265
266	if err != nil {
267		return err
268	}
269	m.Set(fd, val)
270	return nil
271}
272
273// unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
274// the given FieldDescriptor.
275func (d decoder) unmarshalScalar(fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
276	const b32 int = 32
277	const b64 int = 64
278
279	tok, err := d.Read()
280	if err != nil {
281		return protoreflect.Value{}, err
282	}
283
284	kind := fd.Kind()
285	switch kind {
286	case protoreflect.BoolKind:
287		if tok.Kind() == json.Bool {
288			return protoreflect.ValueOfBool(tok.Bool()), nil
289		}
290
291	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
292		if v, ok := unmarshalInt(tok, b32); ok {
293			return v, nil
294		}
295
296	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
297		if v, ok := unmarshalInt(tok, b64); ok {
298			return v, nil
299		}
300
301	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
302		if v, ok := unmarshalUint(tok, b32); ok {
303			return v, nil
304		}
305
306	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
307		if v, ok := unmarshalUint(tok, b64); ok {
308			return v, nil
309		}
310
311	case protoreflect.FloatKind:
312		if v, ok := unmarshalFloat(tok, b32); ok {
313			return v, nil
314		}
315
316	case protoreflect.DoubleKind:
317		if v, ok := unmarshalFloat(tok, b64); ok {
318			return v, nil
319		}
320
321	case protoreflect.StringKind:
322		if tok.Kind() == json.String {
323			return protoreflect.ValueOfString(tok.ParsedString()), nil
324		}
325
326	case protoreflect.BytesKind:
327		if v, ok := unmarshalBytes(tok); ok {
328			return v, nil
329		}
330
331	case protoreflect.EnumKind:
332		if v, ok := unmarshalEnum(tok, fd); ok {
333			return v, nil
334		}
335
336	default:
337		panic(fmt.Sprintf("unmarshalScalar: invalid scalar kind %v", kind))
338	}
339
340	return protoreflect.Value{}, d.newError(tok.Pos(), "invalid value for %v type: %v", kind, tok.RawString())
341}
342
343func unmarshalInt(tok json.Token, bitSize int) (protoreflect.Value, bool) {
344	switch tok.Kind() {
345	case json.Number:
346		return getInt(tok, bitSize)
347
348	case json.String:
349		// Decode number from string.
350		s := strings.TrimSpace(tok.ParsedString())
351		if len(s) != len(tok.ParsedString()) {
352			return protoreflect.Value{}, false
353		}
354		dec := json.NewDecoder([]byte(s))
355		tok, err := dec.Read()
356		if err != nil {
357			return protoreflect.Value{}, false
358		}
359		return getInt(tok, bitSize)
360	}
361	return protoreflect.Value{}, false
362}
363
364func getInt(tok json.Token, bitSize int) (protoreflect.Value, bool) {
365	n, ok := tok.Int(bitSize)
366	if !ok {
367		return protoreflect.Value{}, false
368	}
369	if bitSize == 32 {
370		return protoreflect.ValueOfInt32(int32(n)), true
371	}
372	return protoreflect.ValueOfInt64(n), true
373}
374
375func unmarshalUint(tok json.Token, bitSize int) (protoreflect.Value, bool) {
376	switch tok.Kind() {
377	case json.Number:
378		return getUint(tok, bitSize)
379
380	case json.String:
381		// Decode number from string.
382		s := strings.TrimSpace(tok.ParsedString())
383		if len(s) != len(tok.ParsedString()) {
384			return protoreflect.Value{}, false
385		}
386		dec := json.NewDecoder([]byte(s))
387		tok, err := dec.Read()
388		if err != nil {
389			return protoreflect.Value{}, false
390		}
391		return getUint(tok, bitSize)
392	}
393	return protoreflect.Value{}, false
394}
395
396func getUint(tok json.Token, bitSize int) (protoreflect.Value, bool) {
397	n, ok := tok.Uint(bitSize)
398	if !ok {
399		return protoreflect.Value{}, false
400	}
401	if bitSize == 32 {
402		return protoreflect.ValueOfUint32(uint32(n)), true
403	}
404	return protoreflect.ValueOfUint64(n), true
405}
406
407func unmarshalFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) {
408	switch tok.Kind() {
409	case json.Number:
410		return getFloat(tok, bitSize)
411
412	case json.String:
413		s := tok.ParsedString()
414		switch s {
415		case "NaN":
416			if bitSize == 32 {
417				return protoreflect.ValueOfFloat32(float32(math.NaN())), true
418			}
419			return protoreflect.ValueOfFloat64(math.NaN()), true
420		case "Infinity":
421			if bitSize == 32 {
422				return protoreflect.ValueOfFloat32(float32(math.Inf(+1))), true
423			}
424			return protoreflect.ValueOfFloat64(math.Inf(+1)), true
425		case "-Infinity":
426			if bitSize == 32 {
427				return protoreflect.ValueOfFloat32(float32(math.Inf(-1))), true
428			}
429			return protoreflect.ValueOfFloat64(math.Inf(-1)), true
430		}
431
432		// Decode number from string.
433		if len(s) != len(strings.TrimSpace(s)) {
434			return protoreflect.Value{}, false
435		}
436		dec := json.NewDecoder([]byte(s))
437		tok, err := dec.Read()
438		if err != nil {
439			return protoreflect.Value{}, false
440		}
441		return getFloat(tok, bitSize)
442	}
443	return protoreflect.Value{}, false
444}
445
446func getFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) {
447	n, ok := tok.Float(bitSize)
448	if !ok {
449		return protoreflect.Value{}, false
450	}
451	if bitSize == 32 {
452		return protoreflect.ValueOfFloat32(float32(n)), true
453	}
454	return protoreflect.ValueOfFloat64(n), true
455}
456
457func unmarshalBytes(tok json.Token) (protoreflect.Value, bool) {
458	if tok.Kind() != json.String {
459		return protoreflect.Value{}, false
460	}
461
462	s := tok.ParsedString()
463	enc := base64.StdEncoding
464	if strings.ContainsAny(s, "-_") {
465		enc = base64.URLEncoding
466	}
467	if len(s)%4 != 0 {
468		enc = enc.WithPadding(base64.NoPadding)
469	}
470	b, err := enc.DecodeString(s)
471	if err != nil {
472		return protoreflect.Value{}, false
473	}
474	return protoreflect.ValueOfBytes(b), true
475}
476
477func unmarshalEnum(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.Value, bool) {
478	switch tok.Kind() {
479	case json.String:
480		// Lookup EnumNumber based on name.
481		s := tok.ParsedString()
482		if enumVal := fd.Enum().Values().ByName(protoreflect.Name(s)); enumVal != nil {
483			return protoreflect.ValueOfEnum(enumVal.Number()), true
484		}
485
486	case json.Number:
487		if n, ok := tok.Int(32); ok {
488			return protoreflect.ValueOfEnum(protoreflect.EnumNumber(n)), true
489		}
490
491	case json.Null:
492		// This is only valid for google.protobuf.NullValue.
493		if isNullValue(fd) {
494			return protoreflect.ValueOfEnum(0), true
495		}
496	}
497
498	return protoreflect.Value{}, false
499}
500
501func (d decoder) unmarshalList(list protoreflect.List, fd protoreflect.FieldDescriptor) error {
502	tok, err := d.Read()
503	if err != nil {
504		return err
505	}
506	if tok.Kind() != json.ArrayOpen {
507		return d.unexpectedTokenError(tok)
508	}
509
510	switch fd.Kind() {
511	case protoreflect.MessageKind, protoreflect.GroupKind:
512		for {
513			tok, err := d.Peek()
514			if err != nil {
515				return err
516			}
517
518			if tok.Kind() == json.ArrayClose {
519				d.Read()
520				return nil
521			}
522
523			val := list.NewElement()
524			if err := d.unmarshalMessage(val.Message(), false); err != nil {
525				return err
526			}
527			list.Append(val)
528		}
529	default:
530		for {
531			tok, err := d.Peek()
532			if err != nil {
533				return err
534			}
535
536			if tok.Kind() == json.ArrayClose {
537				d.Read()
538				return nil
539			}
540
541			val, err := d.unmarshalScalar(fd)
542			if err != nil {
543				return err
544			}
545			list.Append(val)
546		}
547	}
548
549	return nil
550}
551
552func (d decoder) unmarshalMap(mmap protoreflect.Map, fd protoreflect.FieldDescriptor) error {
553	tok, err := d.Read()
554	if err != nil {
555		return err
556	}
557	if tok.Kind() != json.ObjectOpen {
558		return d.unexpectedTokenError(tok)
559	}
560
561	// Determine ahead whether map entry is a scalar type or a message type in
562	// order to call the appropriate unmarshalMapValue func inside the for loop
563	// below.
564	var unmarshalMapValue func() (protoreflect.Value, error)
565	switch fd.MapValue().Kind() {
566	case protoreflect.MessageKind, protoreflect.GroupKind:
567		unmarshalMapValue = func() (protoreflect.Value, error) {
568			val := mmap.NewValue()
569			if err := d.unmarshalMessage(val.Message(), false); err != nil {
570				return protoreflect.Value{}, err
571			}
572			return val, nil
573		}
574	default:
575		unmarshalMapValue = func() (protoreflect.Value, error) {
576			return d.unmarshalScalar(fd.MapValue())
577		}
578	}
579
580Loop:
581	for {
582		// Read field name.
583		tok, err := d.Read()
584		if err != nil {
585			return err
586		}
587		switch tok.Kind() {
588		default:
589			return d.unexpectedTokenError(tok)
590		case json.ObjectClose:
591			break Loop
592		case json.Name:
593			// Continue.
594		}
595
596		// Unmarshal field name.
597		pkey, err := d.unmarshalMapKey(tok, fd.MapKey())
598		if err != nil {
599			return err
600		}
601
602		// Check for duplicate field name.
603		if mmap.Has(pkey) {
604			return d.newError(tok.Pos(), "duplicate map key %v", tok.RawString())
605		}
606
607		// Read and unmarshal field value.
608		pval, err := unmarshalMapValue()
609		if err != nil {
610			return err
611		}
612
613		mmap.Set(pkey, pval)
614	}
615
616	return nil
617}
618
619// unmarshalMapKey converts given token of Name kind into a protoreflect.MapKey.
620// A map key type is any integral or string type.
621func (d decoder) unmarshalMapKey(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.MapKey, error) {
622	const b32 = 32
623	const b64 = 64
624	const base10 = 10
625
626	name := tok.Name()
627	kind := fd.Kind()
628	switch kind {
629	case protoreflect.StringKind:
630		return protoreflect.ValueOfString(name).MapKey(), nil
631
632	case protoreflect.BoolKind:
633		switch name {
634		case "true":
635			return protoreflect.ValueOfBool(true).MapKey(), nil
636		case "false":
637			return protoreflect.ValueOfBool(false).MapKey(), nil
638		}
639
640	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
641		if n, err := strconv.ParseInt(name, base10, b32); err == nil {
642			return protoreflect.ValueOfInt32(int32(n)).MapKey(), nil
643		}
644
645	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
646		if n, err := strconv.ParseInt(name, base10, b64); err == nil {
647			return protoreflect.ValueOfInt64(int64(n)).MapKey(), nil
648		}
649
650	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
651		if n, err := strconv.ParseUint(name, base10, b32); err == nil {
652			return protoreflect.ValueOfUint32(uint32(n)).MapKey(), nil
653		}
654
655	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
656		if n, err := strconv.ParseUint(name, base10, b64); err == nil {
657			return protoreflect.ValueOfUint64(uint64(n)).MapKey(), nil
658		}
659
660	default:
661		panic(fmt.Sprintf("invalid kind for map key: %v", kind))
662	}
663
664	return protoreflect.MapKey{}, d.newError(tok.Pos(), "invalid value for %v key: %s", kind, tok.RawString())
665}
666