xref: /aosp_15_r20/external/golang-protobuf/internal/cmd/generate-types/main.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:generate go run . -execute
6
7package main
8
9import (
10	"bytes"
11	"flag"
12	"fmt"
13	"go/format"
14	"io/ioutil"
15	"os"
16	"os/exec"
17	"path"
18	"path/filepath"
19	"regexp"
20	"strconv"
21	"strings"
22	"text/template"
23)
24
// Settings that control where generated output goes. The first two are
// bound to command-line flags in main; repoRoot is filled in by main.
var (
	// run is set by -execute: write generated files to their destinations
	// instead of printing a diff against the existing files.
	run bool

	// outfile is set by -outfile: when non-empty, only the generated file
	// with this repository-relative path is printed to stdout.
	outfile string

	// repoRoot is the repository root reported by
	// `git rev-parse --show-toplevel`; it stays empty in -outfile mode.
	repoRoot string
)
30
31func main() {
32	flag.BoolVar(&run, "execute", false, "Write generated files to destination.")
33	flag.StringVar(&outfile, "outfile", "", "Write this specific file to stdout.")
34	flag.Parse()
35
36	// Determine repository root path.
37	if outfile == "" {
38		out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput()
39		check(err)
40		repoRoot = strings.TrimSpace(string(out))
41		chdirRoot()
42	}
43
44	writeSource("internal/filedesc/desc_list_gen.go", generateDescListTypes())
45	writeSource("internal/impl/codec_gen.go", generateImplCodec())
46	writeSource("internal/impl/message_reflect_gen.go", generateImplMessage())
47	writeSource("internal/impl/merge_gen.go", generateImplMerge())
48	writeSource("proto/decode_gen.go", generateProtoDecode())
49	writeSource("proto/encode_gen.go", generateProtoEncode())
50	writeSource("proto/size_gen.go", generateProtoSize())
51}
52
53// chdirRoot changes the working directory to the repository root.
54func chdirRoot() {
55	out, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput()
56	check(err)
57	check(os.Chdir(strings.TrimSpace(string(out))))
58}
59
// Expr holds the source text of a Go expression that fits on one line.
type Expr string

// DescriptorType names a kind of protobuf descriptor. Its string value is the
// kind's name as spelled inside protoreflect identifiers, e.g. "Message" in
// protoreflect.MessageDescriptor.
type DescriptorType string

// Descriptor kinds understood by the generator.
const (
	MessageDesc   DescriptorType = "Message"
	FieldDesc     DescriptorType = "Field"
	OneofDesc     DescriptorType = "Oneof"
	ExtensionDesc DescriptorType = "Extension"
	EnumDesc      DescriptorType = "Enum"
	EnumValueDesc DescriptorType = "EnumValue"
	ServiceDesc   DescriptorType = "Service"
	MethodDesc    DescriptorType = "Method"
)

// Expr returns the name of the protoreflect descriptor interface for kind d,
// for example "protoreflect.MessageDescriptor" for MessageDesc.
func (d DescriptorType) Expr() Expr {
	return Expr("protoreflect." + string(d) + "Descriptor")
}

// NumberExpr returns the protoreflect number type for kind d:
// "protoreflect.FieldNumber" for FieldDesc, "protoreflect.EnumNumber" for
// EnumValueDesc, and the empty Expr for every other kind.
func (d DescriptorType) NumberExpr() Expr {
	var num Expr
	switch d {
	case FieldDesc:
		num = "protoreflect.FieldNumber"
	case EnumValueDesc:
		num = "protoreflect.EnumNumber"
	}
	return num
}
89
90func generateDescListTypes() string {
91	return mustExecute(descListTypesTemplate, []DescriptorType{
92		EnumDesc, EnumValueDesc, MessageDesc, FieldDesc, OneofDesc, ExtensionDesc, ServiceDesc, MethodDesc,
93	})
94}
95
// descListTypesTemplate is executed with a []DescriptorType. For each kind it
// emits a list type named with the kind plus "s" (e.g., "Messages" for
// "Message"). The list type wraps a List slice of that kind's concrete type,
// with Len, Get, ByName, Format, and ProtoInternal methods.
//
// Field lists also get ByJSONName and ByTextName. Kinds with a non-empty
// NumberExpr also get ByNumber.
//
// The lookup maps are built on first use by lazyInit under a sync.Once. They
// stay nil when List is empty. When several elements share a key, the first
// one in List is kept.
//
// On a miss, each lookup method returns a literal nil instead of the map
// value. The caller therefore gets a nil interface rather than an interface
// wrapping a nil pointer.
var descListTypesTemplate = template.Must(template.New("").Parse(`
	{{- range .}}
	{{$nameList := (printf "%ss" .)}} {{/* e.g., "Messages" */}}
	{{$nameDesc := (printf "%s"  .)}} {{/* e.g., "Message" */}}

	type {{$nameList}} struct {
		List   []{{$nameDesc}}
		once   sync.Once
		byName map[protoreflect.Name]*{{$nameDesc}} // protected by once
		{{- if (eq . "Field")}}
		byJSON map[string]*{{$nameDesc}}            // protected by once
		byText map[string]*{{$nameDesc}}            // protected by once
		{{- end}}
		{{- if .NumberExpr}}
		byNum  map[{{.NumberExpr}}]*{{$nameDesc}}   // protected by once
		{{- end}}
	}

	func (p *{{$nameList}}) Len() int {
		return len(p.List)
	}
	func (p *{{$nameList}}) Get(i int) {{.Expr}} {
		return &p.List[i]
	}
	func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} {
		if d := p.lazyInit().byName[s]; d != nil {
			return d
		}
		return nil
	}
	{{- if (eq . "Field")}}
	func (p *{{$nameList}}) ByJSONName(s string) {{.Expr}} {
		if d := p.lazyInit().byJSON[s]; d != nil {
			return d
		}
		return nil
	}
	func (p *{{$nameList}}) ByTextName(s string) {{.Expr}} {
		if d := p.lazyInit().byText[s]; d != nil {
			return d
		}
		return nil
	}
	{{- end}}
	{{- if .NumberExpr}}
	func (p *{{$nameList}}) ByNumber(n {{.NumberExpr}}) {{.Expr}} {
		if d := p.lazyInit().byNum[n]; d != nil {
			return d
		}
		return nil
	}
	{{- end}}
	func (p *{{$nameList}}) Format(s fmt.State, r rune) {
		descfmt.FormatList(s, r, p)
	}
	func (p *{{$nameList}}) ProtoInternal(pragma.DoNotImplement) {}
	func (p *{{$nameList}}) lazyInit() *{{$nameList}} {
		p.once.Do(func() {
			if len(p.List) > 0 {
				p.byName = make(map[protoreflect.Name]*{{$nameDesc}}, len(p.List))
				{{- if (eq . "Field")}}
				p.byJSON = make(map[string]*{{$nameDesc}}, len(p.List))
				p.byText = make(map[string]*{{$nameDesc}}, len(p.List))
				{{- end}}
				{{- if .NumberExpr}}
				p.byNum = make(map[{{.NumberExpr}}]*{{$nameDesc}}, len(p.List))
				{{- end}}
				for i := range p.List {
					d := &p.List[i]
					if _, ok := p.byName[d.Name()]; !ok {
						p.byName[d.Name()] = d
					}
					{{- if (eq . "Field")}}
					if _, ok := p.byJSON[d.JSONName()]; !ok {
						p.byJSON[d.JSONName()] = d
					}
					if _, ok := p.byText[d.TextName()]; !ok {
						p.byText[d.TextName()] = d
					}
					{{- end}}
					{{- if .NumberExpr}}
					if _, ok := p.byNum[d.Number()]; !ok {
						p.byNum[d.Number()] = d
					}
					{{- end}}
				}
			}
		})
		return p
	}
	{{- end}}
`))
188
// mustExecute runs template t against data and returns the rendered text.
// An execution error is treated as a programming mistake and panics.
func mustExecute(t *template.Template, data interface{}) string {
	out := new(bytes.Buffer)
	err := t.Execute(out, data)
	if err != nil {
		panic(err)
	}
	return out.String()
}
196
197func writeSource(file, src string) {
198	// Crude but effective way to detect used imports.
199	var imports []string
200	for _, pkg := range []string{
201		"fmt",
202		"math",
203		"reflect",
204		"sync",
205		"unicode/utf8",
206		"",
207		"google.golang.org/protobuf/internal/descfmt",
208		"google.golang.org/protobuf/encoding/protowire",
209		"google.golang.org/protobuf/internal/errors",
210		"google.golang.org/protobuf/internal/strs",
211		"google.golang.org/protobuf/internal/pragma",
212		"google.golang.org/protobuf/reflect/protoreflect",
213		"google.golang.org/protobuf/runtime/protoiface",
214	} {
215		if pkg == "" {
216			imports = append(imports, "") // blank line between stdlib and proto packages
217		} else if regexp.MustCompile(`[^\pL_0-9]` + path.Base(pkg) + `\.`).MatchString(src) {
218			imports = append(imports, strconv.Quote(pkg))
219		}
220	}
221
222	s := strings.Join([]string{
223		"// Copyright 2018 The Go Authors. All rights reserved.",
224		"// Use of this source code is governed by a BSD-style",
225		"// license that can be found in the LICENSE file.",
226		"",
227		"// Code generated by generate-types. DO NOT EDIT.",
228		"",
229		"package " + path.Base(path.Dir(path.Join("proto", file))),
230		"",
231		"import (" + strings.Join(imports, "\n") + ")",
232		"",
233		src,
234	}, "\n")
235	b, err := format.Source([]byte(s))
236	if err != nil {
237		// Just print the error and output the unformatted file for examination.
238		fmt.Fprintf(os.Stderr, "%v:%v\n", file, err)
239		b = []byte(s)
240	}
241
242	if outfile != "" {
243		if outfile == file {
244			os.Stdout.Write(b)
245		}
246		return
247	}
248
249	absFile := filepath.Join(repoRoot, file)
250	if run {
251		prev, _ := ioutil.ReadFile(absFile)
252		if !bytes.Equal(b, prev) {
253			fmt.Println("#", file)
254			check(ioutil.WriteFile(absFile, b, 0664))
255		}
256	} else {
257		check(ioutil.WriteFile(absFile+".tmp", b, 0664))
258		defer os.Remove(absFile + ".tmp")
259
260		cmd := exec.Command("diff", file, file+".tmp", "-N", "-u")
261		cmd.Dir = repoRoot
262		cmd.Stdout = os.Stdout
263		cmd.Run()
264	}
265}
266
// check panics with err if it is non-nil; a nil err is a no-op. The generator
// uses it to treat every error as fatal.
func check(err error) {
	if err == nil {
		return
	}
	panic(err)
}
272