1// Copyright 2018 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5//go:generate go run . -execute 6 7package main 8 9import ( 10 "bytes" 11 "flag" 12 "fmt" 13 "go/format" 14 "io/ioutil" 15 "os" 16 "os/exec" 17 "path" 18 "path/filepath" 19 "regexp" 20 "strconv" 21 "strings" 22 "text/template" 23) 24 25var ( 26 run bool 27 outfile string 28 repoRoot string 29) 30 31func main() { 32 flag.BoolVar(&run, "execute", false, "Write generated files to destination.") 33 flag.StringVar(&outfile, "outfile", "", "Write this specific file to stdout.") 34 flag.Parse() 35 36 // Determine repository root path. 37 if outfile == "" { 38 out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput() 39 check(err) 40 repoRoot = strings.TrimSpace(string(out)) 41 chdirRoot() 42 } 43 44 writeSource("internal/filedesc/desc_list_gen.go", generateDescListTypes()) 45 writeSource("internal/impl/codec_gen.go", generateImplCodec()) 46 writeSource("internal/impl/message_reflect_gen.go", generateImplMessage()) 47 writeSource("internal/impl/merge_gen.go", generateImplMerge()) 48 writeSource("proto/decode_gen.go", generateProtoDecode()) 49 writeSource("proto/encode_gen.go", generateProtoEncode()) 50 writeSource("proto/size_gen.go", generateProtoSize()) 51} 52 53// chdirRoot changes the working directory to the repository root. 54func chdirRoot() { 55 out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput() 56 check(err) 57 check(os.Chdir(strings.TrimSpace(string(out)))) 58} 59 60// Expr is a single line Go expression. 61type Expr string 62 63type DescriptorType string 64 65const ( 66 MessageDesc DescriptorType = "Message" 67 FieldDesc DescriptorType = "Field" 68 OneofDesc DescriptorType = "Oneof" 69 ExtensionDesc DescriptorType = "Extension" 70 EnumDesc DescriptorType = "Enum" 71 EnumValueDesc DescriptorType = "EnumValue" 72 ServiceDesc DescriptorType = "Service" 73 MethodDesc DescriptorType = "Method" 74) 75 76func (d DescriptorType) Expr() Expr { 77 return "protoreflect." + Expr(d) + "Descriptor" 78} 79func (d DescriptorType) NumberExpr() Expr { 80 switch d { 81 case FieldDesc: 82 return "protoreflect.FieldNumber" 83 case EnumValueDesc: 84 return "protoreflect.EnumNumber" 85 default: 86 return "" 87 } 88} 89 90func generateDescListTypes() string { 91 return mustExecute(descListTypesTemplate, []DescriptorType{ 92 EnumDesc, EnumValueDesc, MessageDesc, FieldDesc, OneofDesc, ExtensionDesc, ServiceDesc, MethodDesc, 93 }) 94} 95 96var descListTypesTemplate = template.Must(template.New("").Parse(` 97 {{- range .}} 98 {{$nameList := (printf "%ss" .)}} {{/* e.g., "Messages" */}} 99 {{$nameDesc := (printf "%s" .)}} {{/* e.g., "Message" */}} 100 101 type {{$nameList}} struct { 102 List []{{$nameDesc}} 103 once sync.Once 104 byName map[protoreflect.Name]*{{$nameDesc}} // protected by once 105 {{- if (eq . "Field")}} 106 byJSON map[string]*{{$nameDesc}} // protected by once 107 byText map[string]*{{$nameDesc}} // protected by once 108 {{- end}} 109 {{- if .NumberExpr}} 110 byNum map[{{.NumberExpr}}]*{{$nameDesc}} // protected by once 111 {{- end}} 112 } 113 114 func (p *{{$nameList}}) Len() int { 115 return len(p.List) 116 } 117 func (p *{{$nameList}}) Get(i int) {{.Expr}} { 118 return &p.List[i] 119 } 120 func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} { 121 if d := p.lazyInit().byName[s]; d != nil { 122 return d 123 } 124 return nil 125 } 126 {{- if (eq . "Field")}} 127 func (p *{{$nameList}}) ByJSONName(s string) {{.Expr}} { 128 if d := p.lazyInit().byJSON[s]; d != nil { 129 return d 130 } 131 return nil 132 } 133 func (p *{{$nameList}}) ByTextName(s string) {{.Expr}} { 134 if d := p.lazyInit().byText[s]; d != nil { 135 return d 136 } 137 return nil 138 } 139 {{- end}} 140 {{- if .NumberExpr}} 141 func (p *{{$nameList}}) ByNumber(n {{.NumberExpr}}) {{.Expr}} { 142 if d := p.lazyInit().byNum[n]; d != nil { 143 return d 144 } 145 return nil 146 } 147 {{- end}} 148 func (p *{{$nameList}}) Format(s fmt.State, r rune) { 149 descfmt.FormatList(s, r, p) 150 } 151 func (p *{{$nameList}}) ProtoInternal(pragma.DoNotImplement) {} 152 func (p *{{$nameList}}) lazyInit() *{{$nameList}} { 153 p.once.Do(func() { 154 if len(p.List) > 0 { 155 p.byName = make(map[protoreflect.Name]*{{$nameDesc}}, len(p.List)) 156 {{- if (eq . "Field")}} 157 p.byJSON = make(map[string]*{{$nameDesc}}, len(p.List)) 158 p.byText = make(map[string]*{{$nameDesc}}, len(p.List)) 159 {{- end}} 160 {{- if .NumberExpr}} 161 p.byNum = make(map[{{.NumberExpr}}]*{{$nameDesc}}, len(p.List)) 162 {{- end}} 163 for i := range p.List { 164 d := &p.List[i] 165 if _, ok := p.byName[d.Name()]; !ok { 166 p.byName[d.Name()] = d 167 } 168 {{- if (eq . "Field")}} 169 if _, ok := p.byJSON[d.JSONName()]; !ok { 170 p.byJSON[d.JSONName()] = d 171 } 172 if _, ok := p.byText[d.TextName()]; !ok { 173 p.byText[d.TextName()] = d 174 } 175 {{- end}} 176 {{- if .NumberExpr}} 177 if _, ok := p.byNum[d.Number()]; !ok { 178 p.byNum[d.Number()] = d 179 } 180 {{- end}} 181 } 182 } 183 }) 184 return p 185 } 186 {{- end}} 187`)) 188 189func mustExecute(t *template.Template, data interface{}) string { 190 var b bytes.Buffer 191 if err := t.Execute(&b, data); err != nil { 192 panic(err) 193 } 194 return b.String() 195} 196 197func writeSource(file, src string) { 198 // Crude but effective way to detect used imports. 199 var imports []string 200 for _, pkg := range []string{ 201 "fmt", 202 "math", 203 "reflect", 204 "sync", 205 "unicode/utf8", 206 "", 207 "google.golang.org/protobuf/internal/descfmt", 208 "google.golang.org/protobuf/encoding/protowire", 209 "google.golang.org/protobuf/internal/errors", 210 "google.golang.org/protobuf/internal/strs", 211 "google.golang.org/protobuf/internal/pragma", 212 "google.golang.org/protobuf/reflect/protoreflect", 213 "google.golang.org/protobuf/runtime/protoiface", 214 } { 215 if pkg == "" { 216 imports = append(imports, "") // blank line between stdlib and proto packages 217 } else if regexp.MustCompile(`[^\pL_0-9]` + path.Base(pkg) + `\.`).MatchString(src) { 218 imports = append(imports, strconv.Quote(pkg)) 219 } 220 } 221 222 s := strings.Join([]string{ 223 "// Copyright 2018 The Go Authors. All rights reserved.", 224 "// Use of this source code is governed by a BSD-style", 225 "// license that can be found in the LICENSE file.", 226 "", 227 "// Code generated by generate-types. DO NOT EDIT.", 228 "", 229 "package " + path.Base(path.Dir(path.Join("proto", file))), 230 "", 231 "import (" + strings.Join(imports, "\n") + ")", 232 "", 233 src, 234 }, "\n") 235 b, err := format.Source([]byte(s)) 236 if err != nil { 237 // Just print the error and output the unformatted file for examination. 238 fmt.Fprintf(os.Stderr, "%v:%v\n", file, err) 239 b = []byte(s) 240 } 241 242 if outfile != "" { 243 if outfile == file { 244 os.Stdout.Write(b) 245 } 246 return 247 } 248 249 absFile := filepath.Join(repoRoot, file) 250 if run { 251 prev, _ := ioutil.ReadFile(absFile) 252 if !bytes.Equal(b, prev) { 253 fmt.Println("#", file) 254 check(ioutil.WriteFile(absFile, b, 0664)) 255 } 256 } else { 257 check(ioutil.WriteFile(absFile+".tmp", b, 0664)) 258 defer os.Remove(absFile + ".tmp") 259 260 cmd := exec.Command("diff", file, file+".tmp", "-N", "-u") 261 cmd.Dir = repoRoot 262 cmd.Stdout = os.Stdout 263 cmd.Run() 264 } 265} 266 267func check(err error) { 268 if err != nil { 269 panic(err) 270 } 271} 272