xref: /aosp_15_r20/external/golang-protobuf/internal/impl/decode.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"math/bits"
9
10	"google.golang.org/protobuf/encoding/protowire"
11	"google.golang.org/protobuf/internal/errors"
12	"google.golang.org/protobuf/internal/flags"
13	"google.golang.org/protobuf/proto"
14	"google.golang.org/protobuf/reflect/protoreflect"
15	"google.golang.org/protobuf/reflect/protoregistry"
16	"google.golang.org/protobuf/runtime/protoiface"
17)
18
var (
	// errDecode reports that the input is not well-formed wire-format data
	// (bad varint, out-of-range field number, mismatched end-group tag,
	// truncated field, and so on).
	errDecode = errors.New("cannot parse invalid wire-format data")

	// errRecursionDepth reports that message nesting exhausted the depth
	// budget carried in unmarshalOptions.depth.
	errRecursionDepth = errors.New("exceeded maximum recursion depth")
)
21
// unmarshalOptions is the internal form of the options threaded through the
// fast-path unmarshal functions in this package.
type unmarshalOptions struct {
	// flags holds the caller's protoiface.UnmarshalInputFlags; see
	// DiscardUnknown and IsDefault.
	flags    protoiface.UnmarshalInputFlags
	// resolver looks up extension types, either by full name or by the
	// extended message's name plus field number (the latter is used by
	// unmarshalExtension).
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
	// depth is the remaining nesting budget. unmarshalPointer decrements it on
	// entry and fails with errRecursionDepth once it drops below zero.
	depth int
}
30
31func (o unmarshalOptions) Options() proto.UnmarshalOptions {
32	return proto.UnmarshalOptions{
33		Merge:          true,
34		AllowPartial:   true,
35		DiscardUnknown: o.DiscardUnknown(),
36		Resolver:       o.resolver,
37	}
38}
39
40func (o unmarshalOptions) DiscardUnknown() bool {
41	return o.flags&protoiface.UnmarshalDiscardUnknown != 0
42}
43
44func (o unmarshalOptions) IsDefault() bool {
45	return o.flags == 0 && o.resolver == protoregistry.GlobalTypes
46}
47
48var lazyUnmarshalOptions = unmarshalOptions{
49	resolver: protoregistry.GlobalTypes,
50	depth:    protowire.DefaultRecursionLimit,
51}
52
// unmarshalOutput is the result returned by the internal unmarshal functions.
type unmarshalOutput struct {
	n           int // number of bytes consumed
	// initialized is true when decoding found no missing required fields,
	// including those reported by nested field coders (see unmarshalPointer).
	initialized bool
}
57
58// unmarshal is protoreflect.Methods.Unmarshal.
59func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
60	var p pointer
61	if ms, ok := in.Message.(*messageState); ok {
62		p = ms.pointer()
63	} else {
64		p = in.Message.(*messageReflectWrapper).pointer()
65	}
66	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
67		flags:    in.Flags,
68		resolver: in.Resolver,
69		depth:    in.Depth,
70	})
71	var flags protoiface.UnmarshalOutputFlags
72	if out.initialized {
73		flags |= protoiface.UnmarshalInitialized
74	}
75	return protoiface.UnmarshalOutput{
76		Flags: flags,
77	}, err
78}
79
// errUnknown is an internal sentinel that a field decoder returns to mean
// "keep this field's bytes in the unknown-fields section". An example is a
// wire type that does not match the field's declared type. This differs from
// a real error that aborts the whole unmarshal, such as a field extending past
// the end of the input.
//
// unmarshalPointer intercepts this sentinel; it should never be visible to the user.
var errUnknown = errors.New("unknown")
87
// unmarshalPointer decodes the wire-format bytes b into the message at p,
// which is described by mi. Existing contents of p are not cleared first.
//
// If groupTag is non-zero, b holds the body of a group with that field number.
// Decoding stops after the matching end-group tag, and it is an error if the
// input runs out first. An end-group tag whose number differs from groupTag is
// always rejected with errDecode.
//
// Each call consumes one level of opts.depth and fails with errRecursionDepth
// when the budget is exhausted. Fields with no coder, unresolvable extensions
// and wire-type mismatches (signalled by errUnknown) are kept as unknown bytes
// unless opts.DiscardUnknown() is set or mi has no unknown-field storage.
//
// On success out.n is the number of bytes consumed (including a terminating
// end-group tag). out.initialized is true when every nested value reported
// itself initialized and every required field of mi was seen.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	// Charge one level of nesting for this message.
	opts.depth--
	if opts.depth < 0 {
		return out, errRecursionDepth
	}
	// Legacy MessageSet wire format has its own decoder.
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	// One bit per required field seen; compared against numRequiredFields below.
	var requiredMask uint64
	// Extension map, fetched (and allocated if nil) on the first field that
	// might be an extension.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// One- and two-byte varint tags are decoded inline as a fast path;
		// longer ones fall back to protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errDecode
			}
			// Mark the group as properly terminated; this break exits the for loop.
			groupTag = 0
			break
		}

		// Small field numbers are looked up in the dense slice; larger ones in
		// mi.coderFields (presumably a map keyed by field number; confirm).
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// Start from errUnknown so that any path that does not decode the field
		// falls into the unknown-field handling below.
		// NOTE(review): this err shadows the named result err; harmless because
		// every return in this function is explicit.
		err := errUnknown
		// Note: each `break` inside this switch leaves the switch, not the loop.
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				// The message has no extension storage; treat as unknown.
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip over the field's value; if it cannot be skipped, the
			// input itself is malformed.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				// Re-encode the tag (already consumed above) followed by the raw value.
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ended inside a group without its end-group tag.
	if groupTag != 0 {
		return out, errDecode
	}
	// Fewer distinct required bits than required fields means one was missing.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
205}
206
207func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
208	x := exts[int32(num)]
209	xt := x.Type()
210	if xt == nil {
211		var err error
212		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
213		if err != nil {
214			if err == protoregistry.NotFound {
215				return out, errUnknown
216			}
217			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
218		}
219	}
220	xi := getExtensionFieldInfo(xt)
221	if xi.funcs.unmarshal == nil {
222		return out, errUnknown
223	}
224	if flags.LazyUnmarshalExtensions {
225		if opts.IsDefault() && x.canLazy(xt) {
226			out, valid := skipExtension(b, xi, num, wtyp, opts)
227			switch valid {
228			case ValidationValid:
229				if out.initialized {
230					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
231					exts[int32(num)] = x
232					return out, nil
233				}
234			case ValidationInvalid:
235				return out, errDecode
236			case ValidationUnknown:
237			}
238		}
239	}
240	ival := x.Value()
241	if !ival.IsValid() && xi.unmarshalNeedsValue {
242		// Create a new message, list, or map value to fill in.
243		// For enums, create a prototype value to let the unmarshal func know the
244		// concrete type.
245		ival = xt.New()
246	}
247	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
248	if err != nil {
249		return out, err
250	}
251	if xi.funcs.isInit == nil {
252		out.initialized = true
253	}
254	x.Set(xt, v)
255	exts[int32(num)] = x
256	return out, nil
257}
258
259func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
260	if xi.validation.mi == nil {
261		return out, ValidationUnknown
262	}
263	xi.validation.mi.init()
264	switch xi.validation.typ {
265	case validationTypeMessage:
266		if wtyp != protowire.BytesType {
267			return out, ValidationUnknown
268		}
269		v, n := protowire.ConsumeBytes(b)
270		if n < 0 {
271			return out, ValidationUnknown
272		}
273		out, st := xi.validation.mi.validate(v, 0, opts)
274		out.n = n
275		return out, st
276	case validationTypeGroup:
277		if wtyp != protowire.StartGroupType {
278			return out, ValidationUnknown
279		}
280		out, st := xi.validation.mi.validate(b, num, opts)
281		return out, st
282	default:
283		return out, ValidationUnknown
284	}
285}
286