// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file is a simple protocol buffer encoder and decoder.
//
// A protocol message must implement the message interface:
//	decoder() []decoder
//	encode(*buffer)
//
// The decoder method returns a slice indexed by field number that gives the
// function to decode that field.
// The encode method encodes its receiver into the given buffer.
//
// The two methods are simple enough to be implemented by hand rather than
// by using a protocol compiler.
//
// See profile.go for examples of messages implementing this interface.
//
// There is no support for groups, message sets, or "has" bits.

package profile

import (
	"errors"
	"fmt"
)

// buffer holds the state of an in-progress encode or decode.
//
// While encoding, data accumulates the output bytes. While decoding, field,
// typ, u64, and data describe the most recently decoded field: u64 holds the
// value of varint and fixed-width fields (wire types 0, 1, 5) and data holds
// the payload of a length-delimited field (wire type 2).
type buffer struct {
	field int
	typ   int
	u64   uint64
	data  []byte
	tmp   [16]byte // scratch space used by finishLengthDelimited
}

// decoder decodes the field currently described by the buffer into the message.
type decoder func(*buffer, message) error

// message is implemented by every message type encoded by this file.
type message interface {
	decoder() []decoder
	encode(*buffer)
}

// marshal returns the protocol buffer encoding of m.
func marshal(m message) []byte {
	var b buffer
	m.encode(&b)
	return b.data
}

// encodeVarint appends x to b.data as a base-128 varint: seven bits per byte,
// least significant group first, with the high bit set on every byte except
// the last.
func encodeVarint(b *buffer, x uint64) {
	for x >= 128 {
		b.data = append(b.data, byte(x)|0x80)
		x >>= 7
	}
	b.data = append(b.data, byte(x))
}

// encodeLength appends the key of a length-delimited field (wire type 2) with
// the given tag, followed by the payload length n. The payload itself is not
// written.
func encodeLength(b *buffer, tag int, n int) {
	encodeVarint(b, uint64(tag)<<3|2)
	encodeVarint(b, uint64(n))
}

// encodeUint64 appends a varint field (wire type 0) with the given tag and value.
func encodeUint64(b *buffer, tag int, x uint64) {
	encodeVarint(b, uint64(tag)<<3|0)
	encodeVarint(b, x)
}

// finishLengthDelimited turns the bytes appended to b.data since offset start
// into a length-delimited field with the given tag. It appends the key and
// length after the payload and then rotates them in front of it, which avoids
// encoding the payload twice just to learn its size.
func finishLengthDelimited(b *buffer, tag int, start int) {
	end := len(b.data)
	encodeLength(b, tag, end-start)
	prefix := len(b.data) - end
	// Valid protobuf field numbers (< 2^29) need at most 5 key bytes and an
	// int length at most 9, so the prefix always fits in b.tmp.
	copy(b.tmp[:], b.data[end:])
	copy(b.data[start+prefix:], b.data[start:end]) // copy handles the overlap
	copy(b.data[start:], b.tmp[:prefix])
}

// encodeUint64s appends x as the repeated field tag. More than two values are
// written in packed form (a single length-delimited field holding all the
// varints); otherwise each value is written as its own varint field.
func encodeUint64s(b *buffer, tag int, x []uint64) {
	if len(x) > 2 {
		start := len(b.data)
		for _, u := range x {
			encodeVarint(b, u)
		}
		finishLengthDelimited(b, tag, start)
		return
	}
	for _, u := range x {
		encodeUint64(b, tag, u)
	}
}

// encodeUint64Opt is like encodeUint64 but writes nothing when x is zero.
func encodeUint64Opt(b *buffer, tag int, x uint64) {
	if x == 0 {
		return
	}
	encodeUint64(b, tag, x)
}

// encodeInt64 appends a varint field holding x. Negative values are written
// in their 64-bit two's-complement form (protobuf int64, not zigzag).
func encodeInt64(b *buffer, tag int, x int64) {
	encodeUint64(b, tag, uint64(x))
}

// encodeInt64Opt is like encodeInt64 but writes nothing when x is zero.
func encodeInt64Opt(b *buffer, tag int, x int64) {
	if x == 0 {
		return
	}
	encodeInt64(b, tag, x)
}

// encodeInt64s appends x as the repeated field tag, using packed encoding for
// more than two values exactly like encodeUint64s.
func encodeInt64s(b *buffer, tag int, x []int64) {
	if len(x) > 2 {
		start := len(b.data)
		for _, v := range x {
			encodeVarint(b, uint64(v))
		}
		finishLengthDelimited(b, tag, start)
		return
	}
	for _, v := range x {
		encodeInt64(b, tag, v)
	}
}

// encodeString appends a length-delimited field holding the bytes of x.
// Empty strings are still written.
func encodeString(b *buffer, tag int, x string) {
	encodeLength(b, tag, len(x))
	b.data = append(b.data, x...)
}

// encodeStrings appends one string field per element of x.
func encodeStrings(b *buffer, tag int, x []string) {
	for _, s := range x {
		encodeString(b, tag, s)
	}
}

// encodeBool appends a varint field holding 1 for true and 0 for false.
func encodeBool(b *buffer, tag int, x bool) {
	var v uint64
	if x {
		v = 1
	}
	encodeUint64(b, tag, v)
}

// encodeBoolOpt is like encodeBool but writes nothing when x is false.
func encodeBoolOpt(b *buffer, tag int, x bool) {
	if !x {
		return
	}
	encodeBool(b, tag, x)
}

// encodeMessage appends m as a nested, length-delimited message field.
func encodeMessage(b *buffer, tag int, m message) {
	start := len(b.data)
	m.encode(b)
	finishLengthDelimited(b, tag, start)
}

// unmarshal decodes the protocol buffer encoding in data into m. The input is
// presented as the payload of a length-delimited field (typ 2) so that
// decodeMessage can handle the top level the same way as a nested message.
func unmarshal(data []byte, m message) error {
	b := buffer{data: data, typ: 2}
	return decodeMessage(&b, m)
}

// le64 returns the little-endian uint64 stored in the first 8 bytes of p.
func le64(p []byte) uint64 {
	return uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 | uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
}

// le32 returns the little-endian uint32 stored in the first 4 bytes of p.
func le32(p []byte) uint32 {
	return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
}

// decodeVarint decodes a base-128 varint from the front of data and returns
// its value and the bytes that follow it. It reports an error if data ends
// before the varint does or if the varint does not fit in 64 bits.
func decodeVarint(data []byte) (uint64, []byte, error) {
	var u uint64
	for i := 0; i < len(data) && i < 10; i++ {
		c := data[i]
		// The tenth byte can only supply bit 63; any larger value (including
		// a continuation bit) would overflow a uint64.
		if i == 9 && c > 1 {
			break
		}
		u |= uint64(c&0x7F) << uint(7*i)
		if c&0x80 == 0 {
			return u, data[i+1:], nil
		}
	}
	return 0, nil, errors.New("bad varint")
}

// decodeField decodes one field key and value from the front of data into b
// and returns the bytes that follow it. Afterwards b.field and b.typ hold the
// field number and wire type; b.u64 holds the value for wire types 0, 1, and
// 5; b.data holds the payload for wire type 2 and aliases data (no copy).
func decodeField(b *buffer, data []byte) ([]byte, error) {
	x, data, err := decodeVarint(data)
	if err != nil {
		return nil, err
	}
	b.field = int(x >> 3)
	b.typ = int(x & 7)
	b.data = nil
	b.u64 = 0
	switch b.typ {
	case 0:
		b.u64, data, err = decodeVarint(data)
		if err != nil {
			return nil, err
		}
	case 1:
		if len(data) < 8 {
			return nil, errors.New("not enough data")
		}
		b.u64 = le64(data[:8])
		data = data[8:]
	case 2:
		var n uint64
		n, data, err = decodeVarint(data)
		if err != nil {
			return nil, err
		}
		if n > uint64(len(data)) {
			return nil, errors.New("too much data")
		}
		b.data = data[:n]
		data = data[n:]
	case 5:
		if len(data) < 4 {
			return nil, errors.New("not enough data")
		}
		b.u64 = uint64(le32(data[:4]))
		data = data[4:]
	default:
		return nil, fmt.Errorf("unknown wire type: %d", b.typ)
	}
	return data, nil
}

// checkType reports an error if the current field's wire type is not typ.
func checkType(b *buffer, typ int) error {
	if b.typ != typ {
		return errors.New("type mismatch")
	}
	return nil
}

// decodeMessage decodes the length-delimited payload in b.data into m,
// dispatching each field to m.decoder()[field]. Fields without a decoder are
// skipped. The buffer's field state is overwritten as fields are decoded.
func decodeMessage(b *buffer, m message) error {
	if err := checkType(b, 2); err != nil {
		return err
	}
	dec := m.decoder()
	data := b.data
	for len(data) > 0 {
		// Pull the next field key (number + wire type) and its value.
		var err error
		data, err = decodeField(b, data)
		if err != nil {
			return err
		}
		// b.field can be negative on 32-bit platforms, where int(x>>3) may
		// wrap; treat such fields as unknown instead of indexing with them.
		if b.field < 0 || b.field >= len(dec) || dec[b.field] == nil {
			continue
		}
		if err := dec[b.field](b, m); err != nil {
			return err
		}
	}
	return nil
}

// decodeInt64 stores the current varint field into *x, reinterpreting the
// 64 bits as signed.
func decodeInt64(b *buffer, x *int64) error {
	if err := checkType(b, 0); err != nil {
		return err
	}
	*x = int64(b.u64)
	return nil
}

// decodeInt64s appends the current field to *x, accepting either the packed
// form (wire type 2, a run of varints) or a single varint field.
func decodeInt64s(b *buffer, x *[]int64) error {
	if b.typ == 2 {
		// Packed encoding.
		data := b.data
		for len(data) > 0 {
			var u uint64
			var err error

			if u, data, err = decodeVarint(data); err != nil {
				return err
			}
			*x = append(*x, int64(u))
		}
		return nil
	}
	var i int64
	if err := decodeInt64(b, &i); err != nil {
		return err
	}
	*x = append(*x, i)
	return nil
}

// decodeUint64 stores the current varint field into *x.
func decodeUint64(b *buffer, x *uint64) error {
	if err := checkType(b, 0); err != nil {
		return err
	}
	*x = b.u64
	return nil
}

// decodeUint64s appends the current field to *x, accepting either the packed
// form (wire type 2, a run of varints) or a single varint field.
func decodeUint64s(b *buffer, x *[]uint64) error {
	if b.typ == 2 {
		// Packed encoding.
		data := b.data
		for len(data) > 0 {
			var u uint64
			var err error

			if u, data, err = decodeVarint(data); err != nil {
				return err
			}
			*x = append(*x, u)
		}
		return nil
	}
	var u uint64
	if err := decodeUint64(b, &u); err != nil {
		return err
	}
	*x = append(*x, u)
	return nil
}

// decodeString stores a copy of the current length-delimited payload into *x.
func decodeString(b *buffer, x *string) error {
	if err := checkType(b, 2); err != nil {
		return err
	}
	*x = string(b.data)
	return nil
}

// decodeStrings appends the current string field to *x.
func decodeStrings(b *buffer, x *[]string) error {
	var s string
	if err := decodeString(b, &s); err != nil {
		return err
	}
	*x = append(*x, s)
	return nil
}

// decodeBool stores whether the current varint field is non-zero into *x.
func decodeBool(b *buffer, x *bool) error {
	if err := checkType(b, 0); err != nil {
		return err
	}
	*x = b.u64 != 0
	return nil
}