xref: /aosp_15_r20/external/golang-protobuf/cmd/protoc-gen-go/internal_gengo/main.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package internal_gengo is internal to the protobuf module.
6package internal_gengo
7
8import (
9	"fmt"
10	"go/ast"
11	"go/parser"
12	"go/token"
13	"math"
14	"strconv"
15	"strings"
16	"unicode"
17	"unicode/utf8"
18
19	"google.golang.org/protobuf/compiler/protogen"
20	"google.golang.org/protobuf/internal/encoding/tag"
21	"google.golang.org/protobuf/internal/genid"
22	"google.golang.org/protobuf/internal/version"
23	"google.golang.org/protobuf/reflect/protoreflect"
24	"google.golang.org/protobuf/runtime/protoimpl"
25
26	"google.golang.org/protobuf/types/descriptorpb"
27	"google.golang.org/protobuf/types/pluginpb"
28)
29
// SupportedFeatures reports the set of supported protobuf language features.
//
// The value is a pluginpb.CodeGeneratorResponse feature flag converted to
// uint64; currently only proto3 optional fields are advertised.
var SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)

// GenerateVersionMarkers specifies whether to generate version markers.
//
// When true, the generated file header lists the protoc-gen-go and protoc
// versions, and the file contains constants that enforce version
// compatibility between the generated code and runtime/protoimpl.
var GenerateVersionMarkers = true

// Standard library dependencies.
//
// These are the Go import paths of standard library packages that generated
// code may reference; qualified identifiers are obtained with Ident, e.g.
// mathPackage.Ident("Inf").
const (
	base64Package  = protogen.GoImportPath("encoding/base64")
	mathPackage    = protogen.GoImportPath("math")
	reflectPackage = protogen.GoImportPath("reflect")
	sortPackage    = protogen.GoImportPath("sort")
	stringsPackage = protogen.GoImportPath("strings")
	syncPackage    = protogen.GoImportPath("sync")
	timePackage    = protogen.GoImportPath("time")
	utf8Package    = protogen.GoImportPath("unicode/utf8")
)

// Protobuf library dependencies.
//
// These are declared as an interface type so that they can be more easily
// patched to support unique build environments that impose restrictions
// on the dependencies of generated source code.
var (
	protoPackage         goImportPath = protogen.GoImportPath("google.golang.org/protobuf/proto")
	protoifacePackage    goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoiface")
	protoimplPackage     goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl")
	protojsonPackage     goImportPath = protogen.GoImportPath("google.golang.org/protobuf/encoding/protojson")
	protoreflectPackage  goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect")
	protoregistryPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoregistry")
)

// goImportPath is the subset of protogen.GoImportPath's behavior that the
// generator relies on: the path's string form and construction of a
// package-qualified identifier. Declaring the protobuf dependencies above
// with this interface type lets them be replaced by other implementations.
type goImportPath interface {
	String() string
	Ident(string) protogen.GoIdent
}
66
67// GenerateFile generates the contents of a .pb.go file.
68func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
69	filename := file.GeneratedFilenamePrefix + ".pb.go"
70	g := gen.NewGeneratedFile(filename, file.GoImportPath)
71	f := newFileInfo(file)
72
73	genStandaloneComments(g, f, int32(genid.FileDescriptorProto_Syntax_field_number))
74	genGeneratedHeader(gen, g, f)
75	genStandaloneComments(g, f, int32(genid.FileDescriptorProto_Package_field_number))
76
77	packageDoc := genPackageKnownComment(f)
78	g.P(packageDoc, "package ", f.GoPackageName)
79	g.P()
80
81	// Emit a static check that enforces a minimum version of the proto package.
82	if GenerateVersionMarkers {
83		g.P("const (")
84		g.P("// Verify that this generated code is sufficiently up-to-date.")
85		g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")")
86		g.P("// Verify that runtime/protoimpl is sufficiently up-to-date.")
87		g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")")
88		g.P(")")
89		g.P()
90	}
91
92	for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ {
93		genImport(gen, g, f, imps.Get(i))
94	}
95	for _, enum := range f.allEnums {
96		genEnum(g, f, enum)
97	}
98	for _, message := range f.allMessages {
99		genMessage(g, f, message)
100	}
101	genExtensions(g, f)
102
103	genReflectFileDescriptor(gen, g, f)
104
105	return g
106}
107
108// genStandaloneComments prints all leading comments for a FileDescriptorProto
109// location identified by the field number n.
110func genStandaloneComments(g *protogen.GeneratedFile, f *fileInfo, n int32) {
111	loc := f.Desc.SourceLocations().ByPath(protoreflect.SourcePath{n})
112	for _, s := range loc.LeadingDetachedComments {
113		g.P(protogen.Comments(s))
114		g.P()
115	}
116	if s := loc.LeadingComments; s != "" {
117		g.P(protogen.Comments(s))
118		g.P()
119	}
120}
121
// genGeneratedHeader emits the "Code generated ... DO NOT EDIT." banner.
// When GenerateVersionMarkers is set, it also emits the protoc-gen-go and
// protoc versions; the protoc version is "(unknown)" if the request does not
// carry one. It then writes a line naming the source .proto file, worded
// differently when that file is marked deprecated, followed by a blank line.
func genGeneratedHeader(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) {
	g.P("// Code generated by protoc-gen-go. DO NOT EDIT.")

	if GenerateVersionMarkers {
		compiler := "(unknown)"
		if cv := gen.Request.GetCompilerVersion(); cv != nil {
			compiler = fmt.Sprintf("v%v.%v.%v", cv.GetMajor(), cv.GetMinor(), cv.GetPatch())
			if suffix := cv.GetSuffix(); suffix != "" {
				compiler += "-" + suffix
			}
		}
		g.P("// versions:")
		g.P("// \tprotoc-gen-go ", version.String())
		g.P("// \tprotoc        ", compiler)
	}

	path := f.Desc.Path()
	if !f.Proto.GetOptions().GetDeprecated() {
		g.P("// source: ", path)
	} else {
		g.P("// ", path, " is a deprecated file.")
	}
	g.P()
}
146
// genImport handles a single import of the .proto file being generated.
//
// Nothing is emitted when the imported file is unknown to the plugin or
// lives in the same Go package as f. Otherwise every non-weak import is
// recorded as a Go import, even if nothing references it.
//
// Public imports get additional work. The imported file is generated in
// memory and its GeneratedFile is marked with Skip. Its source is then
// parsed with go/parser, and a forwarding declaration of the form
// "<tok> Name = pkg.Name" is emitted for each exported top-level type, const
// and var it declares. Functions and methods of the imported file are not
// forwarded.
func genImport(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, imp protoreflect.FileImport) {
	impFile, ok := gen.FilesByPath[imp.Path()]
	if !ok {
		// The imported file is not known to the plugin; nothing to emit.
		return
	}
	if impFile.GoImportPath == f.GoImportPath {
		// Don't generate imports or aliases for types in the same Go package.
		return
	}
	// Generate imports for all non-weak dependencies, even if they are not
	// referenced, because other code and tools depend on having the
	// full transitive closure of protocol buffer types in the binary.
	if !imp.IsWeak {
		g.Import(impFile.GoImportPath)
	}
	if !imp.IsPublic {
		return
	}

	// Generate public imports by generating the imported file, parsing it,
	// and extracting every symbol that should receive a forwarding declaration.
	// A failure in either step is reported through gen.Error and skips only
	// the forwarding declarations for this import.
	impGen := GenerateFile(gen, impFile)
	impGen.Skip()
	b, err := impGen.Content()
	if err != nil {
		gen.Error(err)
		return
	}
	fset := token.NewFileSet()
	astFile, err := parser.ParseFile(fset, "", b, parser.ParseComments)
	if err != nil {
		gen.Error(err)
		return
	}
	// genForward emits "<tok> name = <imported package>.name". tok is the
	// declaration keyword (type, const or var). expr is the declared type or
	// value expression, and is nil for a value spec without an explicit value.
	genForward := func(tok token.Token, name string, expr ast.Expr) {
		// Don't import unexported symbols.
		r, _ := utf8.DecodeRuneInString(name)
		if !unicode.IsUpper(r) {
			return
		}
		// Don't import the FileDescriptor.
		if name == impFile.GoDescriptorIdent.GoName {
			return
		}
		// Don't import decls referencing a symbol defined in another package.
		// i.e., don't import decls which are themselves public imports:
		//
		//	type T = somepackage.T
		if _, ok := expr.(*ast.SelectorExpr); ok {
			return
		}
		g.P(tok, " ", name, " = ", impFile.GoImportPath.Ident(name))
	}
	g.P("// Symbols defined in public import of ", imp.Path(), ".")
	g.P()
	// Only general declarations (import/const/type/var) are inspected;
	// function declarations fall through the type switch and are skipped.
	for _, decl := range astFile.Decls {
		switch decl := decl.(type) {
		case *ast.GenDecl:
			for _, spec := range decl.Specs {
				switch spec := spec.(type) {
				case *ast.TypeSpec:
					genForward(decl.Tok, spec.Name.Name, spec.Type)
				case *ast.ValueSpec:
					// A spec may declare several names; pair each with its
					// value expression when one is present.
					for i, name := range spec.Names {
						var expr ast.Expr
						if i < len(spec.Values) {
							expr = spec.Values[i]
						}
						genForward(decl.Tok, name.Name, expr)
					}
				case *ast.ImportSpec:
					// Import specs are not forwarded.
				default:
					// go/ast GenDecl specs are import, value or type specs, so
					// this is not expected to be reached.
					panic(fmt.Sprintf("can't generate forward for spec type %T", spec))
				}
			}
		}
	}
	g.P()
}
226
// genEnum emits all Go declarations for enum e, in this order:
//
//   - the named int32 type and one constant per enum value;
//   - the <Enum>_name (number to name) and <Enum>_value (name to number) maps;
//   - the Enum and String methods;
//   - the reflection methods, via genEnumReflectMethods;
//   - for proto2 enums with genJSONMethod set, a deprecated UnmarshalJSON; and
//   - when genRawDescMethod is set, the deprecated EnumDescriptor method,
//     which also marks the file as needing its raw descriptor.
func genEnum(g *protogen.GeneratedFile, f *fileInfo, e *enumInfo) {
	// Enum type declaration.
	g.Annotate(e.GoIdent.GoName, e.Location)
	leadingComments := appendDeprecationSuffix(e.Comments.Leading,
		e.Desc.ParentFile(),
		e.Desc.Options().(*descriptorpb.EnumOptions).GetDeprecated())
	g.P(leadingComments,
		"type ", e.GoIdent, " int32")

	// Enum value constants.
	g.P("const (")
	for _, value := range e.Values {
		g.Annotate(value.GoIdent.GoName, value.Location)
		leadingComments := appendDeprecationSuffix(value.Comments.Leading,
			value.Desc.ParentFile(),
			value.Desc.Options().(*descriptorpb.EnumValueOptions).GetDeprecated())
		g.P(leadingComments,
			value.GoIdent, " ", e.GoIdent, " = ", value.Desc.Number(),
			trailingComment(value.Comments.Trailing))
	}
	g.P(")")
	g.P()

	// Enum value maps.
	g.P("// Enum value maps for ", e.GoIdent, ".")
	g.P("var (")
	g.P(e.GoIdent.GoName+"_name", " = map[int32]string{")
	for _, value := range e.Values {
		// Several values may share one number (enum aliases). Only the value
		// that ByNumber resolves for that number gets a live entry. The others
		// are commented out, because a Go map literal may not repeat a
		// constant key.
		duplicate := ""
		if value.Desc != e.Desc.Values().ByNumber(value.Desc.Number()) {
			duplicate = "// Duplicate value: "
		}
		g.P(duplicate, value.Desc.Number(), ": ", strconv.Quote(string(value.Desc.Name())), ",")
	}
	g.P("}")
	// Every name, including aliases, gets an entry in the reverse map.
	g.P(e.GoIdent.GoName+"_value", " = map[string]int32{")
	for _, value := range e.Values {
		g.P(strconv.Quote(string(value.Desc.Name())), ": ", value.Desc.Number(), ",")
	}
	g.P("}")
	g.P(")")
	g.P()

	// Enum method.
	//
	// NOTE: A pointer value is needed to represent presence in proto2.
	// Since a proto2 message can reference a proto3 enum, it is useful to
	// always generate this method (even on proto3 enums) to support that case.
	g.P("func (x ", e.GoIdent, ") Enum() *", e.GoIdent, " {")
	g.P("p := new(", e.GoIdent, ")")
	g.P("*p = x")
	g.P("return p")
	g.P("}")
	g.P()

	// String method: looks up the value name through the enum descriptor.
	g.P("func (x ", e.GoIdent, ") String() string {")
	g.P("return ", protoimplPackage.Ident("X"), ".EnumStringOf(x.Descriptor(), ", protoreflectPackage.Ident("EnumNumber"), "(x))")
	g.P("}")
	g.P()

	genEnumReflectMethods(g, f, e)

	// UnmarshalJSON method (deprecated, proto2 only): decodes via
	// protoimpl.X.UnmarshalJSONEnum and stores the result in *x.
	if e.genJSONMethod && e.Desc.Syntax() == protoreflect.Proto2 {
		g.P("// Deprecated: Do not use.")
		g.P("func (x *", e.GoIdent, ") UnmarshalJSON(b []byte) error {")
		g.P("num, err := ", protoimplPackage.Ident("X"), ".UnmarshalJSONEnum(x.Descriptor(), b)")
		g.P("if err != nil {")
		g.P("return err")
		g.P("}")
		g.P("*x = ", e.GoIdent, "(num)")
		g.P("return nil")
		g.P("}")
		g.P()
	}

	// EnumDescriptor method.
	if e.genRawDescMethod {
		// A descriptor source path alternates field numbers and element
		// indexes. The odd positions (the indexes) locate e within the file
		// descriptor and are returned alongside the gzipped raw descriptor.
		var indexes []string
		for i := 1; i < len(e.Location.Path); i += 2 {
			indexes = append(indexes, strconv.Itoa(int(e.Location.Path[i])))
		}
		g.P("// Deprecated: Use ", e.GoIdent, ".Descriptor instead.")
		g.P("func (", e.GoIdent, ") EnumDescriptor() ([]byte, []int) {")
		g.P("return ", rawDescVarName(f), "GZIP(), []int{", strings.Join(indexes, ","), "}")
		g.P("}")
		g.P()
		f.needRawDesc = true
	}
}
318
319func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
320	if m.Desc.IsMapEntry() {
321		return
322	}
323
324	// Message type declaration.
325	g.Annotate(m.GoIdent.GoName, m.Location)
326	leadingComments := appendDeprecationSuffix(m.Comments.Leading,
327		m.Desc.ParentFile(),
328		m.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated())
329	g.P(leadingComments,
330		"type ", m.GoIdent, " struct {")
331	genMessageFields(g, f, m)
332	g.P("}")
333	g.P()
334
335	genMessageKnownFunctions(g, f, m)
336	genMessageDefaultDecls(g, f, m)
337	genMessageMethods(g, f, m)
338	genMessageOneofWrapperTypes(g, f, m)
339}
340
341func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
342	sf := f.allMessageFieldsByPtr[m]
343	genMessageInternalFields(g, f, m, sf)
344	for _, field := range m.Fields {
345		genMessageField(g, f, m, field, sf)
346	}
347}
348
349func genMessageInternalFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, sf *structFields) {
350	g.P(genid.State_goname, " ", protoimplPackage.Ident("MessageState"))
351	sf.append(genid.State_goname)
352	g.P(genid.SizeCache_goname, " ", protoimplPackage.Ident("SizeCache"))
353	sf.append(genid.SizeCache_goname)
354	if m.hasWeak {
355		g.P(genid.WeakFields_goname, " ", protoimplPackage.Ident("WeakFields"))
356		sf.append(genid.WeakFields_goname)
357	}
358	g.P(genid.UnknownFields_goname, " ", protoimplPackage.Ident("UnknownFields"))
359	sf.append(genid.UnknownFields_goname)
360	if m.Desc.ExtensionRanges().Len() > 0 {
361		g.P(genid.ExtensionFields_goname, " ", protoimplPackage.Ident("ExtensionFields"))
362		sf.append(genid.ExtensionFields_goname)
363	}
364	if sf.count > 0 {
365		g.P()
366	}
367}
368
369func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields) {
370	if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() {
371		// It would be a bit simpler to iterate over the oneofs below,
372		// but generating the field here keeps the contents of the Go
373		// struct in the same order as the contents of the source
374		// .proto file.
375		if oneof.Fields[0] != field {
376			return // only generate for first appearance
377		}
378
379		tags := structTags{
380			{"protobuf_oneof", string(oneof.Desc.Name())},
381		}
382		if m.isTracked {
383			tags = append(tags, gotrackTags...)
384		}
385
386		g.Annotate(m.GoIdent.GoName+"."+oneof.GoName, oneof.Location)
387		leadingComments := oneof.Comments.Leading
388		if leadingComments != "" {
389			leadingComments += "\n"
390		}
391		ss := []string{fmt.Sprintf(" Types that are assignable to %s:\n", oneof.GoName)}
392		for _, field := range oneof.Fields {
393			ss = append(ss, "\t*"+field.GoIdent.GoName+"\n")
394		}
395		leadingComments += protogen.Comments(strings.Join(ss, ""))
396		g.P(leadingComments,
397			oneof.GoName, " ", oneofInterfaceName(oneof), tags)
398		sf.append(oneof.GoName)
399		return
400	}
401	goType, pointer := fieldGoType(g, f, field)
402	if pointer {
403		goType = "*" + goType
404	}
405	tags := structTags{
406		{"protobuf", fieldProtobufTagValue(field)},
407		{"json", fieldJSONTagValue(field)},
408	}
409	if field.Desc.IsMap() {
410		key := field.Message.Fields[0]
411		val := field.Message.Fields[1]
412		tags = append(tags, structTags{
413			{"protobuf_key", fieldProtobufTagValue(key)},
414			{"protobuf_val", fieldProtobufTagValue(val)},
415		}...)
416	}
417	if m.isTracked {
418		tags = append(tags, gotrackTags...)
419	}
420
421	name := field.GoName
422	if field.Desc.IsWeak() {
423		name = genid.WeakFieldPrefix_goname + name
424	}
425	g.Annotate(m.GoIdent.GoName+"."+name, field.Location)
426	leadingComments := appendDeprecationSuffix(field.Comments.Leading,
427		field.Desc.ParentFile(),
428		field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
429	g.P(leadingComments,
430		name, " ", goType, tags,
431		trailingComment(field.Comments.Trailing))
432	sf.append(field.GoName)
433}
434
// genMessageDefaultDecls generates consts and vars holding the default
// values of fields.
//
// For each field of m that declares an explicit default, a declaration named
// Default_<Message>_<Field> is produced. Values expressible as Go constants
// are grouped in a const block. The rest (bytes values, and the floating-point
// values NaN, +Inf, -Inf and -0) are grouped in a var block. A trailing blank
// line is always emitted.
func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
	var consts, vars []string
	for _, field := range m.Fields {
		if !field.Desc.HasDefault() {
			continue
		}
		name := "Default_" + m.GoIdent.GoName + "_" + field.GoName
		goType, _ := fieldGoType(g, f, field)
		defVal := field.Desc.Default()
		switch field.Desc.Kind() {
		case protoreflect.StringKind:
			consts = append(consts, fmt.Sprintf("%s = %s(%q)", name, goType, defVal.String()))
		case protoreflect.BytesKind:
			// Slices cannot be Go constants.
			vars = append(vars, fmt.Sprintf("%s = %s(%q)", name, goType, defVal.Bytes()))
		case protoreflect.EnumKind:
			idx := field.Desc.DefaultEnumValue().Index()
			val := field.Enum.Values[idx]
			if val.GoIdent.GoImportPath == f.GoImportPath {
				consts = append(consts, fmt.Sprintf("%s = %s", name, g.QualifiedGoIdent(val.GoIdent)))
			} else {
				// If the enum value is declared in a different Go package,
				// reference it by number since the name may not be correct.
				// See https://github.com/golang/protobuf/issues/513.
				consts = append(consts, fmt.Sprintf("%s = %s(%d) // %s",
					name, g.QualifiedGoIdent(field.Enum.GoIdent), val.Desc.Number(), g.QualifiedGoIdent(val.GoIdent)))
			}
		case protoreflect.FloatKind, protoreflect.DoubleKind:
			// Go constants are exact values. No constant denotes NaN or an
			// infinity, and the constant expression -0 is simply 0, so
			// "float64(-0)" would yield +0. Such defaults are computed through
			// the math package in the var block instead.
			mathVar := func(fn, args string) string {
				return fmt.Sprintf("%s = %s(%s(%s))", name, goType, g.QualifiedGoIdent(mathPackage.Ident(fn)), args)
			}
			switch v := defVal.Float(); {
			case math.IsInf(v, -1):
				vars = append(vars, mathVar("Inf", "-1"))
			case math.IsInf(v, +1):
				vars = append(vars, mathVar("Inf", "+1"))
			case math.IsNaN(v):
				vars = append(vars, mathVar("NaN", ""))
			case v == 0 && math.Signbit(v):
				vars = append(vars, mathVar("Copysign", "0, -1"))
			default:
				consts = append(consts, fmt.Sprintf("%s = %s(%v)", name, goType, v))
			}
		default:
			consts = append(consts, fmt.Sprintf("%s = %s(%v)", name, goType, defVal.Interface()))
		}
	}
	if len(consts) > 0 {
		g.P("// Default values for ", m.GoIdent, " fields.")
		g.P("const (")
		for _, s := range consts {
			g.P(s)
		}
		g.P(")")
	}
	if len(vars) > 0 {
		g.P("// Default values for ", m.GoIdent, " fields.")
		g.P("var (")
		for _, s := range vars {
			g.P(s)
		}
		g.P(")")
	}
	g.P()
}
500
501func genMessageMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
502	genMessageBaseMethods(g, f, m)
503	genMessageGetterMethods(g, f, m)
504	genMessageSetterMethods(g, f, m)
505}
506
// genMessageBaseMethods emits the methods every generated message has:
// Reset, String, ProtoMessage and the reflection methods (via
// genMessageReflectMethods). When m.genRawDescMethod is set, it also emits the
// deprecated Descriptor method, which marks the file as needing its raw
// descriptor.
func genMessageBaseMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
	implX := protoimplPackage.Ident("X")

	// Reset zeroes the message. When protoimpl.UnsafeEnabled is set, it also
	// stores the message's entry from the file's message-type table into the
	// message state.
	g.P("func (x *", m.GoIdent, ") Reset() {")
	g.P("*x = ", m.GoIdent, "{}")
	g.P("if ", protoimplPackage.Ident("UnsafeEnabled"), " {")
	g.P("mi := &", messageTypesVarName(f), "[", f.allMessagesByPtr[m], "]")
	g.P("ms := ", implX, ".MessageStateOf(", protoimplPackage.Ident("Pointer"), "(x))")
	g.P("ms.StoreMessageInfo(mi)")
	g.P("}")
	g.P("}")
	g.P()

	// String delegates to protoimpl.X.MessageStringOf.
	g.P("func (x *", m.GoIdent, ") String() string {")
	g.P("return ", implX, ".MessageStringOf(x)")
	g.P("}")
	g.P()

	// ProtoMessage is an empty marker method.
	g.P("func (*", m.GoIdent, ") ProtoMessage() {}")
	g.P()

	genMessageReflectMethods(g, f, m)

	if !m.genRawDescMethod {
		return
	}
	// The source path alternates field numbers and element indexes; the
	// indexes (odd positions) locate m within the file descriptor.
	path := m.Location.Path
	var indexes []string
	for i := 1; i < len(path); i += 2 {
		indexes = append(indexes, strconv.Itoa(int(path[i])))
	}
	g.P("// Deprecated: Use ", m.GoIdent, ".ProtoReflect.Descriptor instead.")
	g.P("func (*", m.GoIdent, ") Descriptor() ([]byte, []int) {")
	g.P("return ", rawDescVarName(f), "GZIP(), []int{", strings.Join(indexes, ","), "}")
	g.P("}")
	g.P()
	f.needRawDesc = true
}
546
// genMessageGetterMethods emits a Get<Field> accessor for every field of m.
// Each non-synthetic oneof also gets a Get<Oneof> accessor, emitted just
// before the getter of the oneof's first field.
//
// All getters check the receiver against nil before reading. For fields with
// presence (and a non-nil default), the stored pointer is checked too, and the
// expression from fieldDefaultValue is returned when the field is unset.
func genMessageGetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
	for _, field := range m.Fields {
		// NOTE(review): one pragma is emitted per field iteration. When a
		// oneof getter is emitted below, it is the first following method
		// declaration, so the field's own getter gets no //go:nointerface
		// pragma. Confirm whether that is intended for tracked messages.
		genNoInterfacePragma(g, m.isTracked)

		// Getter for parent oneof.
		if oneof := field.Oneof; oneof != nil && oneof.Fields[0] == field && !oneof.Desc.IsSynthetic() {
			g.Annotate(m.GoIdent.GoName+".Get"+oneof.GoName, oneof.Location)
			g.P("func (m *", m.GoIdent.GoName, ") Get", oneof.GoName, "() ", oneofInterfaceName(oneof), " {")
			g.P("if m != nil {")
			g.P("return m.", oneof.GoName)
			g.P("}")
			g.P("return nil")
			g.P("}")
			g.P()
		}

		// Getter for message field.
		goType, pointer := fieldGoType(g, f, field)
		defaultValue := fieldDefaultValue(g, f, m, field)
		g.Annotate(m.GoIdent.GoName+".Get"+field.GoName, field.Location)
		leadingComments := appendDeprecationSuffix("",
			field.Desc.ParentFile(),
			field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
		switch {
		case field.Desc.IsWeak():
			// Weak fields are read through protoimpl.X.GetWeak, keyed by field
			// number and the full name of the field's message type.
			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", protoPackage.Ident("Message"), "{")
			g.P("var w ", protoimplPackage.Ident("WeakFields"))
			g.P("if x != nil {")
			g.P("w = x.", genid.WeakFields_goname)
			if m.isTracked {
				g.P("_ = x.", genid.WeakFieldPrefix_goname+field.GoName)
			}
			g.P("}")
			g.P("return ", protoimplPackage.Ident("X"), ".GetWeak(w, ", field.Desc.Number(), ", ", strconv.Quote(string(field.Message.Desc.FullName())), ")")
			g.P("}")
		case field.Oneof != nil && !field.Oneof.Desc.IsSynthetic():
			// Oneof members: type-assert the oneof value to this member's
			// wrapper type and fall back to the default on mismatch.
			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", goType, " {")
			g.P("if x, ok := x.Get", field.Oneof.GoName, "().(*", field.GoIdent, "); ok {")
			g.P("return x.", field.GoName)
			g.P("}")
			g.P("return ", defaultValue)
			g.P("}")
		default:
			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", goType, " {")
			if !field.Desc.HasPresence() || defaultValue == "nil" {
				g.P("if x != nil {")
			} else {
				// The field is stored behind a pointer that must be non-nil
				// before it can be dereferenced.
				g.P("if x != nil && x.", field.GoName, " != nil {")
			}
			star := ""
			if pointer {
				star = "*"
			}
			g.P("return ", star, " x.", field.GoName)
			g.P("}")
			g.P("return ", defaultValue)
			g.P("}")
		}
		g.P()
	}
}
608
// genMessageSetterMethods emits a Set<Field> method for every weak field of
// m; other fields get no generated setter. Each setter stores v through
// protoimpl.X.SetWeak, keyed by the field number and the full name of the
// field's message type. For tracked messages, the setter also references the
// weak placeholder field and is preceded by a //go:nointerface pragma.
func genMessageSetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
	for _, fd := range m.Fields {
		if fd.Desc.IsWeak() {
			genNoInterfacePragma(g, m.isTracked)

			g.Annotate(m.GoIdent.GoName+".Set"+fd.GoName, fd.Location)
			doc := appendDeprecationSuffix("",
				fd.Desc.ParentFile(),
				fd.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
			msgName := strconv.Quote(string(fd.Message.Desc.FullName()))

			g.P(doc, "func (x *", m.GoIdent, ") Set", fd.GoName, "(v ", protoPackage.Ident("Message"), ") {")
			g.P("var w *", protoimplPackage.Ident("WeakFields"))
			g.P("if x != nil {")
			g.P("w = &x.", genid.WeakFields_goname)
			if m.isTracked {
				g.P("_ = x.", genid.WeakFieldPrefix_goname+fd.GoName)
			}
			g.P("}")
			g.P(protoimplPackage.Ident("X"), ".SetWeak(w, ", fd.Desc.Number(), ", ", msgName, ", v)")
			g.P("}")
			g.P()
		}
	}
}
634
635// fieldGoType returns the Go type used for a field.
636//
637// If it returns pointer=true, the struct field is a pointer to the type.
638func fieldGoType(g *protogen.GeneratedFile, f *fileInfo, field *protogen.Field) (goType string, pointer bool) {
639	if field.Desc.IsWeak() {
640		return "struct{}", false
641	}
642
643	pointer = field.Desc.HasPresence()
644	switch field.Desc.Kind() {
645	case protoreflect.BoolKind:
646		goType = "bool"
647	case protoreflect.EnumKind:
648		goType = g.QualifiedGoIdent(field.Enum.GoIdent)
649	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
650		goType = "int32"
651	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
652		goType = "uint32"
653	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
654		goType = "int64"
655	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
656		goType = "uint64"
657	case protoreflect.FloatKind:
658		goType = "float32"
659	case protoreflect.DoubleKind:
660		goType = "float64"
661	case protoreflect.StringKind:
662		goType = "string"
663	case protoreflect.BytesKind:
664		goType = "[]byte"
665		pointer = false // rely on nullability of slices for presence
666	case protoreflect.MessageKind, protoreflect.GroupKind:
667		goType = "*" + g.QualifiedGoIdent(field.Message.GoIdent)
668		pointer = false // pointer captured as part of the type
669	}
670	switch {
671	case field.Desc.IsList():
672		return "[]" + goType, false
673	case field.Desc.IsMap():
674		keyType, _ := fieldGoType(g, f, field.Message.Fields[0])
675		valType, _ := fieldGoType(g, f, field.Message.Fields[1])
676		return fmt.Sprintf("map[%v]%v", keyType, valType), false
677	}
678	return goType, pointer
679}
680
681func fieldProtobufTagValue(field *protogen.Field) string {
682	var enumName string
683	if field.Desc.Kind() == protoreflect.EnumKind {
684		enumName = protoimpl.X.LegacyEnumName(field.Enum.Desc)
685	}
686	return tag.Marshal(field.Desc, enumName)
687}
688
// fieldDefaultValue returns the Go expression a getter for field returns when
// the field is unset:
//   - nil for lists, maps, messages and bytes;
//   - the generated Default_<Message>_<Field> declaration (copied, for bytes)
//     when an explicit default exists;
//   - for enums, the enum's first declared value;
//   - otherwise the zero literal of the field's Go type.
func fieldDefaultValue(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field) string {
	if field.Desc.IsList() {
		return "nil"
	}
	if field.Desc.HasDefault() {
		defVar := "Default_" + m.GoIdent.GoName + "_" + field.GoName
		if field.Desc.Kind() != protoreflect.BytesKind {
			return defVar
		}
		// Hand out a fresh copy rather than the shared default slice.
		return "append([]byte(nil), " + defVar + "...)"
	}

	switch field.Desc.Kind() {
	case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.BytesKind:
		return "nil"
	case protoreflect.StringKind:
		return `""`
	case protoreflect.BoolKind:
		return "false"
	case protoreflect.EnumKind:
		first := field.Enum.Values[0]
		if first.GoIdent.GoImportPath != f.GoImportPath {
			// If the enum value is declared in a different Go package,
			// reference it by number since the name may not be correct.
			// See https://github.com/golang/protobuf/issues/513.
			return g.QualifiedGoIdent(field.Enum.GoIdent) + "(" + strconv.Itoa(int(first.Desc.Number())) + ")"
		}
		return g.QualifiedGoIdent(first.GoIdent)
	}
	return "0"
}
721
722func fieldJSONTagValue(field *protogen.Field) string {
723	return string(field.Desc.Name()) + ",omitempty"
724}
725
// genExtensions emits the extension declarations of the file, if it has any:
//
//   - a package-level []protoimpl.ExtensionInfo slice, named by
//     extensionTypesVarName, with one entry per extension in file order; and
//   - for each extended message, in order of first appearance, a var block
//     of E_<Name> pointers into that slice. Each pointer is documented with
//     the extension's proto declaration and any deprecation notice.
func genExtensions(g *protogen.GeneratedFile, f *fileInfo) {
	if len(f.allExtensions) == 0 {
		return
	}

	g.P("var ", extensionTypesVarName(f), " = []", protoimplPackage.Ident("ExtensionInfo"), "{")
	for _, x := range f.allExtensions {
		g.P("{")
		g.P("ExtendedType: (*", x.Extendee.GoIdent, ")(nil),")
		goType, pointer := fieldGoType(g, f, x.Extension)
		if pointer {
			goType = "*" + goType
		}
		g.P("ExtensionType: (", goType, ")(nil),")
		g.P("Field: ", x.Desc.Number(), ",")
		g.P("Name: ", strconv.Quote(string(x.Desc.FullName())), ",")
		g.P("Tag: ", strconv.Quote(fieldProtobufTagValue(x.Extension)), ",")
		g.P("Filename: ", strconv.Quote(f.Desc.Path()), ",")
		g.P("},")
	}
	g.P("}")
	g.P()

	// Group extensions by the target message.
	// orderedTargets preserves first-appearance order (map iteration order is
	// random), and allExtensionsByPtr maps each extension to its index in the
	// ExtensionInfo slice emitted above.
	var orderedTargets []protogen.GoIdent
	allExtensionsByTarget := make(map[protogen.GoIdent][]*extensionInfo)
	allExtensionsByPtr := make(map[*extensionInfo]int)
	for i, x := range f.allExtensions {
		target := x.Extendee.GoIdent
		if len(allExtensionsByTarget[target]) == 0 {
			orderedTargets = append(orderedTargets, target)
		}
		allExtensionsByTarget[target] = append(allExtensionsByTarget[target], x)
		allExtensionsByPtr[x] = i
	}
	for _, target := range orderedTargets {
		g.P("// Extension fields to ", target, ".")
		g.P("var (")
		for _, x := range allExtensionsByTarget[target] {
			// The type shown in the comment is the kind name for scalars and
			// the full proto name for enums and messages.
			xd := x.Desc
			typeName := xd.Kind().String()
			switch xd.Kind() {
			case protoreflect.EnumKind:
				typeName = string(xd.Enum().FullName())
			case protoreflect.MessageKind, protoreflect.GroupKind:
				typeName = string(xd.Message().FullName())
			}
			fieldName := string(xd.Name())

			// Leading comments, then a synthesized
			// "<cardinality> <type> <name> = <number>;" line, then any
			// deprecation notice.
			leadingComments := x.Comments.Leading
			if leadingComments != "" {
				leadingComments += "\n"
			}
			leadingComments += protogen.Comments(fmt.Sprintf(" %v %v %v = %v;\n",
				xd.Cardinality(), typeName, fieldName, xd.Number()))
			leadingComments = appendDeprecationSuffix(leadingComments,
				x.Desc.ParentFile(),
				x.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
			g.P(leadingComments,
				"E_", x.GoIdent, " = &", extensionTypesVarName(f), "[", allExtensionsByPtr[x], "]",
				trailingComment(x.Comments.Trailing))
		}
		g.P(")")
		g.P()
	}
}
792
793// genMessageOneofWrapperTypes generates the oneof wrapper types and
794// associates the types with the parent message type.
795func genMessageOneofWrapperTypes(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
796	for _, oneof := range m.Oneofs {
797		if oneof.Desc.IsSynthetic() {
798			continue
799		}
800		ifName := oneofInterfaceName(oneof)
801		g.P("type ", ifName, " interface {")
802		g.P(ifName, "()")
803		g.P("}")
804		g.P()
805		for _, field := range oneof.Fields {
806			g.Annotate(field.GoIdent.GoName, field.Location)
807			g.Annotate(field.GoIdent.GoName+"."+field.GoName, field.Location)
808			g.P("type ", field.GoIdent, " struct {")
809			goType, _ := fieldGoType(g, f, field)
810			tags := structTags{
811				{"protobuf", fieldProtobufTagValue(field)},
812			}
813			if m.isTracked {
814				tags = append(tags, gotrackTags...)
815			}
816			leadingComments := appendDeprecationSuffix(field.Comments.Leading,
817				field.Desc.ParentFile(),
818				field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
819			g.P(leadingComments,
820				field.GoName, " ", goType, tags,
821				trailingComment(field.Comments.Trailing))
822			g.P("}")
823			g.P()
824		}
825		for _, field := range oneof.Fields {
826			g.P("func (*", field.GoIdent, ") ", ifName, "() {}")
827			g.P()
828		}
829	}
830}
831
832// oneofInterfaceName returns the name of the interface type implemented by
833// the oneof field value types.
834func oneofInterfaceName(oneof *protogen.Oneof) string {
835	return "is" + oneof.GoIdent.GoName
836}
837
838// genNoInterfacePragma generates a standalone "nointerface" pragma to
839// decorate methods with field-tracking support.
840func genNoInterfacePragma(g *protogen.GeneratedFile, tracked bool) {
841	if tracked {
842		g.P("//go:nointerface")
843		g.P()
844	}
845}
846
847var gotrackTags = structTags{{"go", "track"}}
848
849// structTags is a data structure for build idiomatic Go struct tags.
850// Each [2]string is a key-value pair, where value is the unescaped string.
851//
852// Example: structTags{{"key", "value"}}.String() -> `key:"value"`
853type structTags [][2]string
854
855func (tags structTags) String() string {
856	if len(tags) == 0 {
857		return ""
858	}
859	var ss []string
860	for _, tag := range tags {
861		// NOTE: When quoting the value, we need to make sure the backtick
862		// character does not appear. Convert all cases to the escaped hex form.
863		key := tag[0]
864		val := strings.Replace(strconv.Quote(tag[1]), "`", `\x60`, -1)
865		ss = append(ss, fmt.Sprintf("%s:%s", key, val))
866	}
867	return "`" + strings.Join(ss, " ") + "`"
868}
869
870// appendDeprecationSuffix optionally appends a deprecation notice as a suffix.
871func appendDeprecationSuffix(prefix protogen.Comments, parentFile protoreflect.FileDescriptor, deprecated bool) protogen.Comments {
872	fileDeprecated := parentFile.Options().(*descriptorpb.FileOptions).GetDeprecated()
873	if !deprecated && !fileDeprecated {
874		return prefix
875	}
876	if prefix != "" {
877		prefix += "\n"
878	}
879	if fileDeprecated {
880		return prefix + " Deprecated: The entire proto file " + protogen.Comments(parentFile.Path()) + " is marked as deprecated.\n"
881	}
882	return prefix + " Deprecated: Marked as deprecated in " + protogen.Comments(parentFile.Path()) + ".\n"
883}
884
885// trailingComment is like protogen.Comments, but lacks a trailing newline.
886type trailingComment protogen.Comments
887
888func (c trailingComment) String() string {
889	s := strings.TrimSuffix(protogen.Comments(c).String(), "\n")
890	if strings.Contains(s, "\n") {
891		// We don't support multi-lined trailing comments as it is unclear
892		// how to best render them in the generated code.
893		return ""
894	}
895	return s
896}
897