1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file is a simple protocol buffer encoder and decoder.
6//
7// A protocol message must implement the message interface:
8//   decoder() []decoder
9//   encode(*buffer)
10//
// The decoder method returns a slice indexed by field number that gives the
// function to decode that field.
13// The encode method encodes its receiver into the given buffer.
14//
15// The two methods are simple enough to be implemented by hand rather than
16// by using a protocol compiler.
17//
18// See profile.go for examples of messages implementing this interface.
19//
20// There is no support for groups, message sets, or "has" bits.
21
22package profile
23
24import (
25	"errors"
26	"fmt"
27)
28
// buffer holds both encoder output and decoder state.
//
// While encoding, bytes are appended to data. tmp is scratch space used to
// move a field's key+length header in front of a payload that has already
// been written.
//
// While decoding, field and typ hold the field number and wire type of the
// field most recently parsed. Its value is in u64 for the numeric wire types
// and in data for the length-delimited wire type.
type buffer struct {
	field, typ int
	u64        uint64
	data       []byte
	tmp        [16]byte
}
36
37type decoder func(*buffer, message) error
38
39type message interface {
40	decoder() []decoder
41	encode(*buffer)
42}
43
44func marshal(m message) []byte {
45	var b buffer
46	m.encode(&b)
47	return b.data
48}
49
50func encodeVarint(b *buffer, x uint64) {
51	for x >= 128 {
52		b.data = append(b.data, byte(x)|0x80)
53		x >>= 7
54	}
55	b.data = append(b.data, byte(x))
56}
57
58func encodeLength(b *buffer, tag int, len int) {
59	encodeVarint(b, uint64(tag)<<3|2)
60	encodeVarint(b, uint64(len))
61}
62
63func encodeUint64(b *buffer, tag int, x uint64) {
64	// append varint to b.data
65	encodeVarint(b, uint64(tag)<<3|0)
66	encodeVarint(b, x)
67}
68
69func encodeUint64s(b *buffer, tag int, x []uint64) {
70	if len(x) > 2 {
71		// Use packed encoding
72		n1 := len(b.data)
73		for _, u := range x {
74			encodeVarint(b, u)
75		}
76		n2 := len(b.data)
77		encodeLength(b, tag, n2-n1)
78		n3 := len(b.data)
79		copy(b.tmp[:], b.data[n2:n3])
80		copy(b.data[n1+(n3-n2):], b.data[n1:n2])
81		copy(b.data[n1:], b.tmp[:n3-n2])
82		return
83	}
84	for _, u := range x {
85		encodeUint64(b, tag, u)
86	}
87}
88
89func encodeUint64Opt(b *buffer, tag int, x uint64) {
90	if x == 0 {
91		return
92	}
93	encodeUint64(b, tag, x)
94}
95
96func encodeInt64(b *buffer, tag int, x int64) {
97	u := uint64(x)
98	encodeUint64(b, tag, u)
99}
100
101func encodeInt64Opt(b *buffer, tag int, x int64) {
102	if x == 0 {
103		return
104	}
105	encodeInt64(b, tag, x)
106}
107
108func encodeInt64s(b *buffer, tag int, x []int64) {
109	if len(x) > 2 {
110		// Use packed encoding
111		n1 := len(b.data)
112		for _, u := range x {
113			encodeVarint(b, uint64(u))
114		}
115		n2 := len(b.data)
116		encodeLength(b, tag, n2-n1)
117		n3 := len(b.data)
118		copy(b.tmp[:], b.data[n2:n3])
119		copy(b.data[n1+(n3-n2):], b.data[n1:n2])
120		copy(b.data[n1:], b.tmp[:n3-n2])
121		return
122	}
123	for _, u := range x {
124		encodeInt64(b, tag, u)
125	}
126}
127
128func encodeString(b *buffer, tag int, x string) {
129	encodeLength(b, tag, len(x))
130	b.data = append(b.data, x...)
131}
132
133func encodeStrings(b *buffer, tag int, x []string) {
134	for _, s := range x {
135		encodeString(b, tag, s)
136	}
137}
138
139func encodeBool(b *buffer, tag int, x bool) {
140	if x {
141		encodeUint64(b, tag, 1)
142	} else {
143		encodeUint64(b, tag, 0)
144	}
145}
146
147func encodeBoolOpt(b *buffer, tag int, x bool) {
148	if !x {
149		return
150	}
151	encodeBool(b, tag, x)
152}
153
154func encodeMessage(b *buffer, tag int, m message) {
155	n1 := len(b.data)
156	m.encode(b)
157	n2 := len(b.data)
158	encodeLength(b, tag, n2-n1)
159	n3 := len(b.data)
160	copy(b.tmp[:], b.data[n2:n3])
161	copy(b.data[n1+(n3-n2):], b.data[n1:n2])
162	copy(b.data[n1:], b.tmp[:n3-n2])
163}
164
165func unmarshal(data []byte, m message) (err error) {
166	b := buffer{data: data, typ: 2}
167	return decodeMessage(&b, m)
168}
169
// le64 decodes the first eight bytes of p as a little-endian uint64. Any
// bytes after the eighth are ignored. It panics if p has fewer than eight
// bytes.
func le64(p []byte) uint64 {
	_ = p[7] // panic up front if p is too short
	var v uint64
	for i := 7; i >= 0; i-- {
		v = v<<8 | uint64(p[i])
	}
	return v
}
173
// le32 decodes the first four bytes of p as a little-endian uint32. Any
// bytes after the fourth are ignored. It panics if p has fewer than four
// bytes.
func le32(p []byte) uint32 {
	_ = p[3] // panic up front if p is too short
	var v uint32
	for i := 3; i >= 0; i-- {
		v = v<<8 | uint32(p[i])
	}
	return v
}
177
// decodeVarint decodes a base-128 varint from the front of data. It returns
// the value and the bytes following it.
//
// An error is returned if data ends before the varint terminates or if the
// varint would not fit in a uint64. That covers both encodings longer than
// ten bytes and a tenth byte carrying more than the single remaining bit 63.
func decodeVarint(data []byte) (uint64, []byte, error) {
	var u uint64
	for i := 0; i < len(data) && i < 10; i++ {
		c := data[i]
		if i == 9 && c > 1 {
			// Nine bytes supply 63 bits; the tenth may only set bit 63 and
			// must terminate. Anything else would silently overflow.
			break
		}
		u |= uint64(c&0x7F) << uint(7*i)
		if c&0x80 == 0 {
			return u, data[i+1:], nil
		}
	}
	return 0, nil, errors.New("bad varint")
}
191
192func decodeField(b *buffer, data []byte) ([]byte, error) {
193	x, data, err := decodeVarint(data)
194	if err != nil {
195		return nil, err
196	}
197	b.field = int(x >> 3)
198	b.typ = int(x & 7)
199	b.data = nil
200	b.u64 = 0
201	switch b.typ {
202	case 0:
203		b.u64, data, err = decodeVarint(data)
204		if err != nil {
205			return nil, err
206		}
207	case 1:
208		if len(data) < 8 {
209			return nil, errors.New("not enough data")
210		}
211		b.u64 = le64(data[:8])
212		data = data[8:]
213	case 2:
214		var n uint64
215		n, data, err = decodeVarint(data)
216		if err != nil {
217			return nil, err
218		}
219		if n > uint64(len(data)) {
220			return nil, errors.New("too much data")
221		}
222		b.data = data[:n]
223		data = data[n:]
224	case 5:
225		if len(data) < 4 {
226			return nil, errors.New("not enough data")
227		}
228		b.u64 = uint64(le32(data[:4]))
229		data = data[4:]
230	default:
231		return nil, fmt.Errorf("unknown wire type: %d", b.typ)
232	}
233
234	return data, nil
235}
236
237func checkType(b *buffer, typ int) error {
238	if b.typ != typ {
239		return errors.New("type mismatch")
240	}
241	return nil
242}
243
244func decodeMessage(b *buffer, m message) error {
245	if err := checkType(b, 2); err != nil {
246		return err
247	}
248	dec := m.decoder()
249	data := b.data
250	for len(data) > 0 {
251		// pull varint field# + type
252		var err error
253		data, err = decodeField(b, data)
254		if err != nil {
255			return err
256		}
257		if b.field >= len(dec) || dec[b.field] == nil {
258			continue
259		}
260		if err := dec[b.field](b, m); err != nil {
261			return err
262		}
263	}
264	return nil
265}
266
267func decodeInt64(b *buffer, x *int64) error {
268	if err := checkType(b, 0); err != nil {
269		return err
270	}
271	*x = int64(b.u64)
272	return nil
273}
274
275func decodeInt64s(b *buffer, x *[]int64) error {
276	if b.typ == 2 {
277		// Packed encoding
278		data := b.data
279		for len(data) > 0 {
280			var u uint64
281			var err error
282
283			if u, data, err = decodeVarint(data); err != nil {
284				return err
285			}
286			*x = append(*x, int64(u))
287		}
288		return nil
289	}
290	var i int64
291	if err := decodeInt64(b, &i); err != nil {
292		return err
293	}
294	*x = append(*x, i)
295	return nil
296}
297
298func decodeUint64(b *buffer, x *uint64) error {
299	if err := checkType(b, 0); err != nil {
300		return err
301	}
302	*x = b.u64
303	return nil
304}
305
306func decodeUint64s(b *buffer, x *[]uint64) error {
307	if b.typ == 2 {
308		data := b.data
309		// Packed encoding
310		for len(data) > 0 {
311			var u uint64
312			var err error
313
314			if u, data, err = decodeVarint(data); err != nil {
315				return err
316			}
317			*x = append(*x, u)
318		}
319		return nil
320	}
321	var u uint64
322	if err := decodeUint64(b, &u); err != nil {
323		return err
324	}
325	*x = append(*x, u)
326	return nil
327}
328
329func decodeString(b *buffer, x *string) error {
330	if err := checkType(b, 2); err != nil {
331		return err
332	}
333	*x = string(b.data)
334	return nil
335}
336
337func decodeStrings(b *buffer, x *[]string) error {
338	var s string
339	if err := decodeString(b, &s); err != nil {
340		return err
341	}
342	*x = append(*x, s)
343	return nil
344}
345
346func decodeBool(b *buffer, x *bool) error {
347	if err := checkType(b, 0); err != nil {
348		return err
349	}
350	if int64(b.u64) == 0 {
351		*x = false
352	} else {
353		*x = true
354	}
355	return nil
356}
357