1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ast
6
7import (
8	"go/token"
9	"slices"
10)
11
12// ----------------------------------------------------------------------------
13// Export filtering
14
15// exportFilter is a special filter function to extract exported nodes.
16func exportFilter(name string) bool {
17	return IsExported(name)
18}
19
20// FileExports trims the AST for a Go source file in place such that
21// only exported nodes remain: all top-level identifiers which are not exported
22// and their associated information (such as type, initial value, or function
23// body) are removed. Non-exported fields and methods of exported types are
24// stripped. The [File.Comments] list is not changed.
25//
26// FileExports reports whether there are exported declarations.
27func FileExports(src *File) bool {
28	return filterFile(src, exportFilter, true)
29}
30
31// PackageExports trims the AST for a Go package in place such that
32// only exported nodes remain. The pkg.Files list is not changed, so that
33// file names and top-level package comments don't get lost.
34//
35// PackageExports reports whether there are exported declarations;
36// it returns false otherwise.
37func PackageExports(pkg *Package) bool {
38	return filterPackage(pkg, exportFilter, true)
39}
40
41// ----------------------------------------------------------------------------
42// General filtering
43
// A Filter reports whether a name should be kept. It is applied to
// identifiers by FilterDecl, FilterFile, FilterPackage, and the
// export-filtering functions FileExports and PackageExports.
type Filter func(string) bool
45
46func filterIdentList(list []*Ident, f Filter) []*Ident {
47	j := 0
48	for _, x := range list {
49		if f(x.Name) {
50			list[j] = x
51			j++
52		}
53	}
54	return list[0:j]
55}
56
57// fieldName assumes that x is the type of an anonymous field and
58// returns the corresponding field name. If x is not an acceptable
59// anonymous field, the result is nil.
60func fieldName(x Expr) *Ident {
61	switch t := x.(type) {
62	case *Ident:
63		return t
64	case *SelectorExpr:
65		if _, ok := t.X.(*Ident); ok {
66			return t.Sel
67		}
68	case *StarExpr:
69		return fieldName(t.X)
70	}
71	return nil
72}
73
74func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
75	if fields == nil {
76		return false
77	}
78	list := fields.List
79	j := 0
80	for _, f := range list {
81		keepField := false
82		if len(f.Names) == 0 {
83			// anonymous field
84			name := fieldName(f.Type)
85			keepField = name != nil && filter(name.Name)
86		} else {
87			n := len(f.Names)
88			f.Names = filterIdentList(f.Names, filter)
89			if len(f.Names) < n {
90				removedFields = true
91			}
92			keepField = len(f.Names) > 0
93		}
94		if keepField {
95			if export {
96				filterType(f.Type, filter, export)
97			}
98			list[j] = f
99			j++
100		}
101	}
102	if j < len(list) {
103		removedFields = true
104	}
105	fields.List = list[0:j]
106	return
107}
108
109func filterCompositeLit(lit *CompositeLit, filter Filter, export bool) {
110	n := len(lit.Elts)
111	lit.Elts = filterExprList(lit.Elts, filter, export)
112	if len(lit.Elts) < n {
113		lit.Incomplete = true
114	}
115}
116
117func filterExprList(list []Expr, filter Filter, export bool) []Expr {
118	j := 0
119	for _, exp := range list {
120		switch x := exp.(type) {
121		case *CompositeLit:
122			filterCompositeLit(x, filter, export)
123		case *KeyValueExpr:
124			if x, ok := x.Key.(*Ident); ok && !filter(x.Name) {
125				continue
126			}
127			if x, ok := x.Value.(*CompositeLit); ok {
128				filterCompositeLit(x, filter, export)
129			}
130		}
131		list[j] = exp
132		j++
133	}
134	return list[0:j]
135}
136
137func filterParamList(fields *FieldList, filter Filter, export bool) bool {
138	if fields == nil {
139		return false
140	}
141	var b bool
142	for _, f := range fields.List {
143		if filterType(f.Type, filter, export) {
144			b = true
145		}
146	}
147	return b
148}
149
150func filterType(typ Expr, f Filter, export bool) bool {
151	switch t := typ.(type) {
152	case *Ident:
153		return f(t.Name)
154	case *ParenExpr:
155		return filterType(t.X, f, export)
156	case *ArrayType:
157		return filterType(t.Elt, f, export)
158	case *StructType:
159		if filterFieldList(t.Fields, f, export) {
160			t.Incomplete = true
161		}
162		return len(t.Fields.List) > 0
163	case *FuncType:
164		b1 := filterParamList(t.Params, f, export)
165		b2 := filterParamList(t.Results, f, export)
166		return b1 || b2
167	case *InterfaceType:
168		if filterFieldList(t.Methods, f, export) {
169			t.Incomplete = true
170		}
171		return len(t.Methods.List) > 0
172	case *MapType:
173		b1 := filterType(t.Key, f, export)
174		b2 := filterType(t.Value, f, export)
175		return b1 || b2
176	case *ChanType:
177		return filterType(t.Value, f, export)
178	}
179	return false
180}
181
182func filterSpec(spec Spec, f Filter, export bool) bool {
183	switch s := spec.(type) {
184	case *ValueSpec:
185		s.Names = filterIdentList(s.Names, f)
186		s.Values = filterExprList(s.Values, f, export)
187		if len(s.Names) > 0 {
188			if export {
189				filterType(s.Type, f, export)
190			}
191			return true
192		}
193	case *TypeSpec:
194		if f(s.Name.Name) {
195			if export {
196				filterType(s.Type, f, export)
197			}
198			return true
199		}
200		if !export {
201			// For general filtering (not just exports),
202			// filter type even if name is not filtered
203			// out.
204			// If the type contains filtered elements,
205			// keep the declaration.
206			return filterType(s.Type, f, export)
207		}
208	}
209	return false
210}
211
212func filterSpecList(list []Spec, f Filter, export bool) []Spec {
213	j := 0
214	for _, s := range list {
215		if filterSpec(s, f, export) {
216			list[j] = s
217			j++
218		}
219	}
220	return list[0:j]
221}
222
223// FilterDecl trims the AST for a Go declaration in place by removing
224// all names (including struct field and interface method names, but
225// not from parameter lists) that don't pass through the filter f.
226//
227// FilterDecl reports whether there are any declared names left after
228// filtering.
229func FilterDecl(decl Decl, f Filter) bool {
230	return filterDecl(decl, f, false)
231}
232
233func filterDecl(decl Decl, f Filter, export bool) bool {
234	switch d := decl.(type) {
235	case *GenDecl:
236		d.Specs = filterSpecList(d.Specs, f, export)
237		return len(d.Specs) > 0
238	case *FuncDecl:
239		return f(d.Name.Name)
240	}
241	return false
242}
243
244// FilterFile trims the AST for a Go file in place by removing all
245// names from top-level declarations (including struct field and
246// interface method names, but not from parameter lists) that don't
247// pass through the filter f. If the declaration is empty afterwards,
248// the declaration is removed from the AST. Import declarations are
249// always removed. The [File.Comments] list is not changed.
250//
251// FilterFile reports whether there are any top-level declarations
252// left after filtering.
253func FilterFile(src *File, f Filter) bool {
254	return filterFile(src, f, false)
255}
256
257func filterFile(src *File, f Filter, export bool) bool {
258	j := 0
259	for _, d := range src.Decls {
260		if filterDecl(d, f, export) {
261			src.Decls[j] = d
262			j++
263		}
264	}
265	src.Decls = src.Decls[0:j]
266	return j > 0
267}
268
269// FilterPackage trims the AST for a Go package in place by removing
270// all names from top-level declarations (including struct field and
271// interface method names, but not from parameter lists) that don't
272// pass through the filter f. If the declaration is empty afterwards,
273// the declaration is removed from the AST. The pkg.Files list is not
274// changed, so that file names and top-level package comments don't get
275// lost.
276//
277// FilterPackage reports whether there are any top-level declarations
278// left after filtering.
279func FilterPackage(pkg *Package, f Filter) bool {
280	return filterPackage(pkg, f, false)
281}
282
283func filterPackage(pkg *Package, f Filter, export bool) bool {
284	hasDecls := false
285	for _, src := range pkg.Files {
286		if filterFile(src, f, export) {
287			hasDecls = true
288		}
289	}
290	return hasDecls
291}
292
293// ----------------------------------------------------------------------------
294// Merging of package files
295
// The MergeMode flags control the behavior of [MergePackageFiles].
// Each flag occupies its own bit, so flags may be combined with |.
type MergeMode uint

const (
	// FilterFuncDuplicates excludes duplicate function declarations.
	FilterFuncDuplicates MergeMode = 1 << 0
	// FilterUnassociatedComments excludes comments that are not
	// associated with a specific AST node (as Doc or Comment).
	FilterUnassociatedComments MergeMode = 1 << 1
	// FilterImportDuplicates excludes duplicate import declarations.
	FilterImportDuplicates MergeMode = 1 << 2
)
308
309// nameOf returns the function (foo) or method name (foo.bar) for
310// the given function declaration. If the AST is incorrect for the
311// receiver, it assumes a function instead.
312func nameOf(f *FuncDecl) string {
313	if r := f.Recv; r != nil && len(r.List) == 1 {
314		// looks like a correct receiver declaration
315		t := r.List[0].Type
316		// dereference pointer receiver types
317		if p, _ := t.(*StarExpr); p != nil {
318			t = p.X
319		}
320		// the receiver type must be a type name
321		if p, _ := t.(*Ident); p != nil {
322			return p.Name + "." + f.Name.Name
323		}
324		// otherwise assume a function instead
325	}
326	return f.Name.Name
327}
328
329// separator is an empty //-style comment that is interspersed between
330// different comment groups when they are concatenated into a single group
331var separator = &Comment{token.NoPos, "//"}
332
333// MergePackageFiles creates a file AST by merging the ASTs of the
334// files belonging to a package. The mode flags control merging behavior.
335func MergePackageFiles(pkg *Package, mode MergeMode) *File {
336	// Count the number of package docs, comments and declarations across
337	// all package files. Also, compute sorted list of filenames, so that
338	// subsequent iterations can always iterate in the same order.
339	ndocs := 0
340	ncomments := 0
341	ndecls := 0
342	filenames := make([]string, len(pkg.Files))
343	var minPos, maxPos token.Pos
344	i := 0
345	for filename, f := range pkg.Files {
346		filenames[i] = filename
347		i++
348		if f.Doc != nil {
349			ndocs += len(f.Doc.List) + 1 // +1 for separator
350		}
351		ncomments += len(f.Comments)
352		ndecls += len(f.Decls)
353		if i == 0 || f.FileStart < minPos {
354			minPos = f.FileStart
355		}
356		if i == 0 || f.FileEnd > maxPos {
357			maxPos = f.FileEnd
358		}
359	}
360	slices.Sort(filenames)
361
362	// Collect package comments from all package files into a single
363	// CommentGroup - the collected package documentation. In general
364	// there should be only one file with a package comment; but it's
365	// better to collect extra comments than drop them on the floor.
366	var doc *CommentGroup
367	var pos token.Pos
368	if ndocs > 0 {
369		list := make([]*Comment, ndocs-1) // -1: no separator before first group
370		i := 0
371		for _, filename := range filenames {
372			f := pkg.Files[filename]
373			if f.Doc != nil {
374				if i > 0 {
375					// not the first group - add separator
376					list[i] = separator
377					i++
378				}
379				for _, c := range f.Doc.List {
380					list[i] = c
381					i++
382				}
383				if f.Package > pos {
384					// Keep the maximum package clause position as
385					// position for the package clause of the merged
386					// files.
387					pos = f.Package
388				}
389			}
390		}
391		doc = &CommentGroup{list}
392	}
393
394	// Collect declarations from all package files.
395	var decls []Decl
396	if ndecls > 0 {
397		decls = make([]Decl, ndecls)
398		funcs := make(map[string]int) // map of func name -> decls index
399		i := 0                        // current index
400		n := 0                        // number of filtered entries
401		for _, filename := range filenames {
402			f := pkg.Files[filename]
403			for _, d := range f.Decls {
404				if mode&FilterFuncDuplicates != 0 {
405					// A language entity may be declared multiple
406					// times in different package files; only at
407					// build time declarations must be unique.
408					// For now, exclude multiple declarations of
409					// functions - keep the one with documentation.
410					//
411					// TODO(gri): Expand this filtering to other
412					//            entities (const, type, vars) if
413					//            multiple declarations are common.
414					if f, isFun := d.(*FuncDecl); isFun {
415						name := nameOf(f)
416						if j, exists := funcs[name]; exists {
417							// function declared already
418							if decls[j] != nil && decls[j].(*FuncDecl).Doc == nil {
419								// existing declaration has no documentation;
420								// ignore the existing declaration
421								decls[j] = nil
422							} else {
423								// ignore the new declaration
424								d = nil
425							}
426							n++ // filtered an entry
427						} else {
428							funcs[name] = i
429						}
430					}
431				}
432				decls[i] = d
433				i++
434			}
435		}
436
437		// Eliminate nil entries from the decls list if entries were
438		// filtered. We do this using a 2nd pass in order to not disturb
439		// the original declaration order in the source (otherwise, this
440		// would also invalidate the monotonically increasing position
441		// info within a single file).
442		if n > 0 {
443			i = 0
444			for _, d := range decls {
445				if d != nil {
446					decls[i] = d
447					i++
448				}
449			}
450			decls = decls[0:i]
451		}
452	}
453
454	// Collect import specs from all package files.
455	var imports []*ImportSpec
456	if mode&FilterImportDuplicates != 0 {
457		seen := make(map[string]bool)
458		for _, filename := range filenames {
459			f := pkg.Files[filename]
460			for _, imp := range f.Imports {
461				if path := imp.Path.Value; !seen[path] {
462					// TODO: consider handling cases where:
463					// - 2 imports exist with the same import path but
464					//   have different local names (one should probably
465					//   keep both of them)
466					// - 2 imports exist but only one has a comment
467					// - 2 imports exist and they both have (possibly
468					//   different) comments
469					imports = append(imports, imp)
470					seen[path] = true
471				}
472			}
473		}
474	} else {
475		// Iterate over filenames for deterministic order.
476		for _, filename := range filenames {
477			f := pkg.Files[filename]
478			imports = append(imports, f.Imports...)
479		}
480	}
481
482	// Collect comments from all package files.
483	var comments []*CommentGroup
484	if mode&FilterUnassociatedComments == 0 {
485		comments = make([]*CommentGroup, ncomments)
486		i := 0
487		for _, filename := range filenames {
488			f := pkg.Files[filename]
489			i += copy(comments[i:], f.Comments)
490		}
491	}
492
493	// TODO(gri) need to compute unresolved identifiers!
494	return &File{doc, pos, NewIdent(pkg.Name), decls, minPos, maxPos, pkg.Scope, imports, nil, comments, ""}
495}
496