xref: /aosp_15_r20/external/golang-protobuf/proto/decode.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"google.golang.org/protobuf/encoding/protowire"
9	"google.golang.org/protobuf/internal/encoding/messageset"
10	"google.golang.org/protobuf/internal/errors"
11	"google.golang.org/protobuf/internal/flags"
12	"google.golang.org/protobuf/internal/genid"
13	"google.golang.org/protobuf/internal/pragma"
14	"google.golang.org/protobuf/reflect/protoreflect"
15	"google.golang.org/protobuf/reflect/protoregistry"
16	"google.golang.org/protobuf/runtime/protoiface"
17)
18
19// UnmarshalOptions configures the unmarshaler.
20//
21// Example usage:
22//
23//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
24type UnmarshalOptions struct {
25	pragma.NoUnkeyedLiterals
26
27	// Merge merges the input into the destination message.
28	// The default behavior is to always reset the message before unmarshaling,
29	// unless Merge is specified.
30	Merge bool
31
32	// AllowPartial accepts input for messages that will result in missing
33	// required fields. If AllowPartial is false (the default), Unmarshal will
34	// return an error if there are any missing required fields.
35	AllowPartial bool
36
37	// If DiscardUnknown is set, unknown fields are ignored.
38	DiscardUnknown bool
39
40	// Resolver is used for looking up types when unmarshaling extension fields.
41	// If nil, this defaults to using protoregistry.GlobalTypes.
42	Resolver interface {
43		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
44		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
45	}
46
47	// RecursionLimit limits how deeply messages may be nested.
48	// If zero, a default limit is applied.
49	RecursionLimit int
50}
51
52// Unmarshal parses the wire-format message in b and places the result in m.
53// The provided message must be mutable (e.g., a non-nil pointer to a message).
54func Unmarshal(b []byte, m Message) error {
55	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
56	return err
57}
58
59// Unmarshal parses the wire-format message in b and places the result in m.
60// The provided message must be mutable (e.g., a non-nil pointer to a message).
61func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
62	if o.RecursionLimit == 0 {
63		o.RecursionLimit = protowire.DefaultRecursionLimit
64	}
65	_, err := o.unmarshal(b, m.ProtoReflect())
66	return err
67}
68
69// UnmarshalState parses a wire-format message and places the result in m.
70//
71// This method permits fine-grained control over the unmarshaler.
72// Most users should use Unmarshal instead.
73func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
74	if o.RecursionLimit == 0 {
75		o.RecursionLimit = protowire.DefaultRecursionLimit
76	}
77	return o.unmarshal(in.Buf, in.Message)
78}
79
80// unmarshal is a centralized function that all unmarshal operations go through.
81// For profiling purposes, avoid changing the name of this function or
82// introducing other code paths for unmarshal that do not go through this.
83func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
84	if o.Resolver == nil {
85		o.Resolver = protoregistry.GlobalTypes
86	}
87	if !o.Merge {
88		Reset(m.Interface())
89	}
90	allowPartial := o.AllowPartial
91	o.Merge = true
92	o.AllowPartial = true
93	methods := protoMethods(m)
94	if methods != nil && methods.Unmarshal != nil &&
95		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
96		in := protoiface.UnmarshalInput{
97			Message:  m,
98			Buf:      b,
99			Resolver: o.Resolver,
100			Depth:    o.RecursionLimit,
101		}
102		if o.DiscardUnknown {
103			in.Flags |= protoiface.UnmarshalDiscardUnknown
104		}
105		out, err = methods.Unmarshal(in)
106	} else {
107		o.RecursionLimit--
108		if o.RecursionLimit < 0 {
109			return out, errors.New("exceeded max recursion depth")
110		}
111		err = o.unmarshalMessageSlow(b, m)
112	}
113	if err != nil {
114		return out, err
115	}
116	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
117		return out, nil
118	}
119	return out, checkInitialized(m)
120}
121
122func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
123	_, err := o.unmarshal(b, m)
124	return err
125}
126
127func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
128	md := m.Descriptor()
129	if messageset.IsMessageSet(md) {
130		return o.unmarshalMessageSet(b, m)
131	}
132	fields := md.Fields()
133	for len(b) > 0 {
134		// Parse the tag (field number and wire type).
135		num, wtyp, tagLen := protowire.ConsumeTag(b)
136		if tagLen < 0 {
137			return errDecode
138		}
139		if num > protowire.MaxValidNumber {
140			return errDecode
141		}
142
143		// Find the field descriptor for this field number.
144		fd := fields.ByNumber(num)
145		if fd == nil && md.ExtensionRanges().Has(num) {
146			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
147			if err != nil && err != protoregistry.NotFound {
148				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
149			}
150			if extType != nil {
151				fd = extType.TypeDescriptor()
152			}
153		}
154		var err error
155		if fd == nil {
156			err = errUnknown
157		} else if flags.ProtoLegacy {
158			if fd.IsWeak() && fd.Message().IsPlaceholder() {
159				err = errUnknown // weak referent is not linked in
160			}
161		}
162
163		// Parse the field value.
164		var valLen int
165		switch {
166		case err != nil:
167		case fd.IsList():
168			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
169		case fd.IsMap():
170			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
171		default:
172			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
173		}
174		if err != nil {
175			if err != errUnknown {
176				return err
177			}
178			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
179			if valLen < 0 {
180				return errDecode
181			}
182			if !o.DiscardUnknown {
183				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
184			}
185		}
186		b = b[tagLen+valLen:]
187	}
188	return nil
189}
190
191func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
192	v, n, err := o.unmarshalScalar(b, wtyp, fd)
193	if err != nil {
194		return 0, err
195	}
196	switch fd.Kind() {
197	case protoreflect.GroupKind, protoreflect.MessageKind:
198		m2 := m.Mutable(fd).Message()
199		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
200			return n, err
201		}
202	default:
203		// Non-message scalars replace the previous value.
204		m.Set(fd, v)
205	}
206	return n, nil
207}
208
209func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
210	if wtyp != protowire.BytesType {
211		return 0, errUnknown
212	}
213	b, n = protowire.ConsumeBytes(b)
214	if n < 0 {
215		return 0, errDecode
216	}
217	var (
218		keyField = fd.MapKey()
219		valField = fd.MapValue()
220		key      protoreflect.Value
221		val      protoreflect.Value
222		haveKey  bool
223		haveVal  bool
224	)
225	switch valField.Kind() {
226	case protoreflect.GroupKind, protoreflect.MessageKind:
227		val = mapv.NewValue()
228	}
229	// Map entries are represented as a two-element message with fields
230	// containing the key and value.
231	for len(b) > 0 {
232		num, wtyp, n := protowire.ConsumeTag(b)
233		if n < 0 {
234			return 0, errDecode
235		}
236		if num > protowire.MaxValidNumber {
237			return 0, errDecode
238		}
239		b = b[n:]
240		err = errUnknown
241		switch num {
242		case genid.MapEntry_Key_field_number:
243			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
244			if err != nil {
245				break
246			}
247			haveKey = true
248		case genid.MapEntry_Value_field_number:
249			var v protoreflect.Value
250			v, n, err = o.unmarshalScalar(b, wtyp, valField)
251			if err != nil {
252				break
253			}
254			switch valField.Kind() {
255			case protoreflect.GroupKind, protoreflect.MessageKind:
256				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
257					return 0, err
258				}
259			default:
260				val = v
261			}
262			haveVal = true
263		}
264		if err == errUnknown {
265			n = protowire.ConsumeFieldValue(num, wtyp, b)
266			if n < 0 {
267				return 0, errDecode
268			}
269		} else if err != nil {
270			return 0, err
271		}
272		b = b[n:]
273	}
274	// Every map entry should have entries for key and value, but this is not strictly required.
275	if !haveKey {
276		key = keyField.Default()
277	}
278	if !haveVal {
279		switch valField.Kind() {
280		case protoreflect.GroupKind, protoreflect.MessageKind:
281		default:
282			val = valField.Default()
283		}
284	}
285	mapv.Set(key.MapKey(), val)
286	return n, nil
287}
288
289// errUnknown is used internally to indicate fields which should be added
290// to the unknown field set of a message. It is never returned from an exported
291// function.
292var errUnknown = errors.New("BUG: internal error (unknown)")
293
294var errDecode = errors.New("cannot parse invalid wire-format data")
295