1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
import (
	"fmt"
	"go/ast"
	"go/token"
	"path"
	"sort"
	"strconv"
)
14
// fix describes one rewrite registered with this program (see register).
//
// name is the key byName sorts on; date is the key byDate sorts on.
// f is the rewrite callback, invoked with a parsed file; its bool result is
// presumably "the file was changed" — TODO confirm against the callers in main.
// desc is presumably the human-readable description shown when the fixes are
// listed — TODO confirm.
type fix struct {
	name     string
	date     string // date that fix was introduced, in YYYY-MM-DD format
	f        func(*ast.File) bool
	desc     string
	disabled bool // whether this fix should be disabled by default
}
22
23// main runs sort.Sort(byName(fixes)) before printing list of fixes.
24type byName []fix
25
26func (f byName) Len() int           { return len(f) }
27func (f byName) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
28func (f byName) Less(i, j int) bool { return f[i].name < f[j].name }
29
30// main runs sort.Sort(byDate(fixes)) before applying fixes.
31type byDate []fix
32
33func (f byDate) Len() int           { return len(f) }
34func (f byDate) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
35func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date }
36
// fixes is the package-level list that collects every fix passed to register.
var fixes []fix

// register appends f to fixes. It performs no validation or de-duplication
// and no locking.
// NOTE(review): presumably called from the init functions of the individual
// fix files — confirm against the rest of the package.
func register(f fix) {
	fixes = append(fixes, f)
}
42
43// walk traverses the AST x, calling visit(y) for each node y in the tree but
44// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
45// in a bottom-up traversal.
46func walk(x any, visit func(any)) {
47	walkBeforeAfter(x, nop, visit)
48}
49
// nop is a do-nothing callback with the func(any) signature used by the
// traversal hooks; it ignores its argument.
func nop(any) {}
51
// walkBeforeAfter is like walk but calls before(x) before traversing
// x's children and after(x) afterward.
//
// Children are reached through pointers to the fields that hold them
// (*ast.Expr, *ast.Stmt, **ast.BlockStmt, *[]ast.Stmt, ...), and before and
// after are invoked on those pointers as well as on the nodes themselves,
// so a callback receives the addressable slot along with the node in it.
// A nil interface value is passed to the callbacks and otherwise ignored.
//
// Not every field is traversed: SelectorExpr.Sel, the Name of a FuncDecl or
// TypeSpec, statement labels, and the contents of an ImportSpec are not
// visited.
//
// The files of an *ast.Package are visited in sorted file-name order so the
// traversal is deterministic; because they live in a map, they are passed as
// plain *ast.File values rather than as pointers to a slot.
//
// walkBeforeAfter panics if it meets a value whose type it does not know.
func walkBeforeAfter(x any, before, after func(any)) {
	before(x)

	switch n := x.(type) {
	default:
		panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))

	case nil:

	// pointers to interfaces
	case *ast.Decl:
		walkBeforeAfter(*n, before, after)
	case *ast.Expr:
		walkBeforeAfter(*n, before, after)
	case *ast.Spec:
		walkBeforeAfter(*n, before, after)
	case *ast.Stmt:
		walkBeforeAfter(*n, before, after)

	// pointers to struct pointers
	case **ast.BlockStmt:
		walkBeforeAfter(*n, before, after)
	case **ast.CallExpr:
		walkBeforeAfter(*n, before, after)
	case **ast.FieldList:
		walkBeforeAfter(*n, before, after)
	case **ast.FuncType:
		walkBeforeAfter(*n, before, after)
	case **ast.Ident:
		walkBeforeAfter(*n, before, after)
	case **ast.BasicLit:
		walkBeforeAfter(*n, before, after)
	case **ast.File:
		// Produced by the []*ast.File case below.
		walkBeforeAfter(*n, before, after)

	// pointers to slices
	case *[]ast.Decl:
		walkBeforeAfter(*n, before, after)
	case *[]ast.Expr:
		walkBeforeAfter(*n, before, after)
	case *[]*ast.File:
		walkBeforeAfter(*n, before, after)
	case *[]*ast.Ident:
		walkBeforeAfter(*n, before, after)
	case *[]ast.Spec:
		walkBeforeAfter(*n, before, after)
	case *[]ast.Stmt:
		walkBeforeAfter(*n, before, after)

	// These are ordered and grouped to match ../../go/ast/ast.go
	case *ast.Field:
		walkBeforeAfter(&n.Names, before, after)
		walkBeforeAfter(&n.Type, before, after)
		walkBeforeAfter(&n.Tag, before, after)
	case *ast.FieldList:
		// A typed-nil list (e.g. a hand-built FuncType without Params)
		// has no fields to visit.
		if n != nil {
			for _, field := range n.List {
				walkBeforeAfter(field, before, after)
			}
		}
	case *ast.BadExpr:
	case *ast.Ident:
	case *ast.Ellipsis:
		walkBeforeAfter(&n.Elt, before, after)
	case *ast.BasicLit:
	case *ast.FuncLit:
		walkBeforeAfter(&n.Type, before, after)
		walkBeforeAfter(&n.Body, before, after)
	case *ast.CompositeLit:
		walkBeforeAfter(&n.Type, before, after)
		walkBeforeAfter(&n.Elts, before, after)
	case *ast.ParenExpr:
		walkBeforeAfter(&n.X, before, after)
	case *ast.SelectorExpr:
		walkBeforeAfter(&n.X, before, after)
	case *ast.IndexExpr:
		walkBeforeAfter(&n.X, before, after)
		walkBeforeAfter(&n.Index, before, after)
	case *ast.IndexListExpr:
		walkBeforeAfter(&n.X, before, after)
		walkBeforeAfter(&n.Indices, before, after)
	case *ast.SliceExpr:
		walkBeforeAfter(&n.X, before, after)
		if n.Low != nil {
			walkBeforeAfter(&n.Low, before, after)
		}
		if n.High != nil {
			walkBeforeAfter(&n.High, before, after)
		}
		// Max is the third index of a full slice expression a[i:j:k].
		if n.Max != nil {
			walkBeforeAfter(&n.Max, before, after)
		}
	case *ast.TypeAssertExpr:
		walkBeforeAfter(&n.X, before, after)
		walkBeforeAfter(&n.Type, before, after)
	case *ast.CallExpr:
		walkBeforeAfter(&n.Fun, before, after)
		walkBeforeAfter(&n.Args, before, after)
	case *ast.StarExpr:
		walkBeforeAfter(&n.X, before, after)
	case *ast.UnaryExpr:
		walkBeforeAfter(&n.X, before, after)
	case *ast.BinaryExpr:
		walkBeforeAfter(&n.X, before, after)
		walkBeforeAfter(&n.Y, before, after)
	case *ast.KeyValueExpr:
		walkBeforeAfter(&n.Key, before, after)
		walkBeforeAfter(&n.Value, before, after)

	case *ast.ArrayType:
		walkBeforeAfter(&n.Len, before, after)
		walkBeforeAfter(&n.Elt, before, after)
	case *ast.StructType:
		walkBeforeAfter(&n.Fields, before, after)
	case *ast.FuncType:
		if n.TypeParams != nil {
			walkBeforeAfter(&n.TypeParams, before, after)
		}
		walkBeforeAfter(&n.Params, before, after)
		if n.Results != nil {
			walkBeforeAfter(&n.Results, before, after)
		}
	case *ast.InterfaceType:
		walkBeforeAfter(&n.Methods, before, after)
	case *ast.MapType:
		walkBeforeAfter(&n.Key, before, after)
		walkBeforeAfter(&n.Value, before, after)
	case *ast.ChanType:
		walkBeforeAfter(&n.Value, before, after)

	case *ast.BadStmt:
	case *ast.DeclStmt:
		walkBeforeAfter(&n.Decl, before, after)
	case *ast.EmptyStmt:
	case *ast.LabeledStmt:
		walkBeforeAfter(&n.Stmt, before, after)
	case *ast.ExprStmt:
		walkBeforeAfter(&n.X, before, after)
	case *ast.SendStmt:
		walkBeforeAfter(&n.Chan, before, after)
		walkBeforeAfter(&n.Value, before, after)
	case *ast.IncDecStmt:
		walkBeforeAfter(&n.X, before, after)
	case *ast.AssignStmt:
		walkBeforeAfter(&n.Lhs, before, after)
		walkBeforeAfter(&n.Rhs, before, after)
	case *ast.GoStmt:
		walkBeforeAfter(&n.Call, before, after)
	case *ast.DeferStmt:
		walkBeforeAfter(&n.Call, before, after)
	case *ast.ReturnStmt:
		walkBeforeAfter(&n.Results, before, after)
	case *ast.BranchStmt:
	case *ast.BlockStmt:
		walkBeforeAfter(&n.List, before, after)
	case *ast.IfStmt:
		walkBeforeAfter(&n.Init, before, after)
		walkBeforeAfter(&n.Cond, before, after)
		walkBeforeAfter(&n.Body, before, after)
		walkBeforeAfter(&n.Else, before, after)
	case *ast.CaseClause:
		walkBeforeAfter(&n.List, before, after)
		walkBeforeAfter(&n.Body, before, after)
	case *ast.SwitchStmt:
		walkBeforeAfter(&n.Init, before, after)
		walkBeforeAfter(&n.Tag, before, after)
		walkBeforeAfter(&n.Body, before, after)
	case *ast.TypeSwitchStmt:
		walkBeforeAfter(&n.Init, before, after)
		walkBeforeAfter(&n.Assign, before, after)
		walkBeforeAfter(&n.Body, before, after)
	case *ast.CommClause:
		walkBeforeAfter(&n.Comm, before, after)
		walkBeforeAfter(&n.Body, before, after)
	case *ast.SelectStmt:
		walkBeforeAfter(&n.Body, before, after)
	case *ast.ForStmt:
		walkBeforeAfter(&n.Init, before, after)
		walkBeforeAfter(&n.Cond, before, after)
		walkBeforeAfter(&n.Post, before, after)
		walkBeforeAfter(&n.Body, before, after)
	case *ast.RangeStmt:
		walkBeforeAfter(&n.Key, before, after)
		walkBeforeAfter(&n.Value, before, after)
		walkBeforeAfter(&n.X, before, after)
		walkBeforeAfter(&n.Body, before, after)

	case *ast.ImportSpec:
	case *ast.ValueSpec:
		walkBeforeAfter(&n.Type, before, after)
		walkBeforeAfter(&n.Values, before, after)
		walkBeforeAfter(&n.Names, before, after)
	case *ast.TypeSpec:
		if n.TypeParams != nil {
			walkBeforeAfter(&n.TypeParams, before, after)
		}
		walkBeforeAfter(&n.Type, before, after)

	case *ast.BadDecl:
	case *ast.GenDecl:
		walkBeforeAfter(&n.Specs, before, after)
	case *ast.FuncDecl:
		if n.Recv != nil {
			walkBeforeAfter(&n.Recv, before, after)
		}
		walkBeforeAfter(&n.Type, before, after)
		if n.Body != nil {
			walkBeforeAfter(&n.Body, before, after)
		}

	case *ast.File:
		walkBeforeAfter(&n.Decls, before, after)

	case *ast.Package:
		// Files is a map keyed by file name: map elements are not
		// addressable and map iteration order is random, so walk the
		// file values themselves in sorted name order.
		names := make([]string, 0, len(n.Files))
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			walkBeforeAfter(n.Files[name], before, after)
		}

	case []*ast.File:
		for i := range n {
			walkBeforeAfter(&n[i], before, after)
		}
	case []ast.Decl:
		for i := range n {
			walkBeforeAfter(&n[i], before, after)
		}
	case []ast.Expr:
		for i := range n {
			walkBeforeAfter(&n[i], before, after)
		}
	case []*ast.Ident:
		for i := range n {
			walkBeforeAfter(&n[i], before, after)
		}
	case []ast.Stmt:
		for i := range n {
			walkBeforeAfter(&n[i], before, after)
		}
	case []ast.Spec:
		for i := range n {
			walkBeforeAfter(&n[i], before, after)
		}
	}
	after(x)
}
290
291// imports reports whether f imports path.
292func imports(f *ast.File, path string) bool {
293	return importSpec(f, path) != nil
294}
295
296// importSpec returns the import spec if f imports path,
297// or nil otherwise.
298func importSpec(f *ast.File, path string) *ast.ImportSpec {
299	for _, s := range f.Imports {
300		if importPath(s) == path {
301			return s
302		}
303	}
304	return nil
305}
306
// importPath returns the import path of s with its quotes removed.
// If the path literal is not a properly quoted string, it returns "".
func importPath(s *ast.ImportSpec) string {
	unquoted, err := strconv.Unquote(s.Path.Value)
	if err != nil {
		return ""
	}
	return unquoted
}
316
317// declImports reports whether gen contains an import of path.
318func declImports(gen *ast.GenDecl, path string) bool {
319	if gen.Tok != token.IMPORT {
320		return false
321	}
322	for _, spec := range gen.Specs {
323		impspec := spec.(*ast.ImportSpec)
324		if importPath(impspec) == path {
325			return true
326		}
327	}
328	return false
329}
330
// isTopName reports whether n is an identifier called name that the parser
// left unresolved (its Obj is nil), i.e. a top-level unresolved identifier
// with the given name.
func isTopName(n ast.Expr, name string) bool {
	id, isIdent := n.(*ast.Ident)
	if !isIdent {
		return false
	}
	return id.Obj == nil && id.Name == name
}
336
337// renameTop renames all references to the top-level name old.
338// It reports whether it makes any changes.
339func renameTop(f *ast.File, old, new string) bool {
340	var fixed bool
341
342	// Rename any conflicting imports
343	// (assuming package name is last element of path).
344	for _, s := range f.Imports {
345		if s.Name != nil {
346			if s.Name.Name == old {
347				s.Name.Name = new
348				fixed = true
349			}
350		} else {
351			_, thisName := path.Split(importPath(s))
352			if thisName == old {
353				s.Name = ast.NewIdent(new)
354				fixed = true
355			}
356		}
357	}
358
359	// Rename any top-level declarations.
360	for _, d := range f.Decls {
361		switch d := d.(type) {
362		case *ast.FuncDecl:
363			if d.Recv == nil && d.Name.Name == old {
364				d.Name.Name = new
365				d.Name.Obj.Name = new
366				fixed = true
367			}
368		case *ast.GenDecl:
369			for _, s := range d.Specs {
370				switch s := s.(type) {
371				case *ast.TypeSpec:
372					if s.Name.Name == old {
373						s.Name.Name = new
374						s.Name.Obj.Name = new
375						fixed = true
376					}
377				case *ast.ValueSpec:
378					for _, n := range s.Names {
379						if n.Name == old {
380							n.Name = new
381							n.Obj.Name = new
382							fixed = true
383						}
384					}
385				}
386			}
387		}
388	}
389
390	// Rename top-level old to new, both unresolved names
391	// (probably defined in another file) and names that resolve
392	// to a declaration we renamed.
393	walk(f, func(n any) {
394		id, ok := n.(*ast.Ident)
395		if ok && isTopName(id, old) {
396			id.Name = new
397			fixed = true
398		}
399		if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
400			id.Name = id.Obj.Name
401			fixed = true
402		}
403	})
404
405	return fixed
406}
407
// matchLen returns the number of leading bytes that x and y have in common,
// i.e. the length of their longest shared prefix.
func matchLen(x, y string) int {
	limit := len(x)
	if len(y) < limit {
		limit = len(y)
	}
	for i := 0; i < limit; i++ {
		if x[i] != y[i] {
			return i
		}
	}
	return limit
}
416
417// addImport adds the import path to the file f, if absent.
418func addImport(f *ast.File, ipath string) (added bool) {
419	if imports(f, ipath) {
420		return false
421	}
422
423	// Determine name of import.
424	// Assume added imports follow convention of using last element.
425	_, name := path.Split(ipath)
426
427	// Rename any conflicting top-level references from name to name_.
428	renameTop(f, name, name+"_")
429
430	newImport := &ast.ImportSpec{
431		Path: &ast.BasicLit{
432			Kind:  token.STRING,
433			Value: strconv.Quote(ipath),
434		},
435	}
436
437	// Find an import decl to add to.
438	var (
439		bestMatch  = -1
440		lastImport = -1
441		impDecl    *ast.GenDecl
442		impIndex   = -1
443	)
444	for i, decl := range f.Decls {
445		gen, ok := decl.(*ast.GenDecl)
446		if ok && gen.Tok == token.IMPORT {
447			lastImport = i
448			// Do not add to import "C", to avoid disrupting the
449			// association with its doc comment, breaking cgo.
450			if declImports(gen, "C") {
451				continue
452			}
453
454			// Compute longest shared prefix with imports in this block.
455			for j, spec := range gen.Specs {
456				impspec := spec.(*ast.ImportSpec)
457				n := matchLen(importPath(impspec), ipath)
458				if n > bestMatch {
459					bestMatch = n
460					impDecl = gen
461					impIndex = j
462				}
463			}
464		}
465	}
466
467	// If no import decl found, add one after the last import.
468	if impDecl == nil {
469		impDecl = &ast.GenDecl{
470			Tok: token.IMPORT,
471		}
472		f.Decls = append(f.Decls, nil)
473		copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
474		f.Decls[lastImport+1] = impDecl
475	}
476
477	// Ensure the import decl has parentheses, if needed.
478	if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
479		impDecl.Lparen = impDecl.Pos()
480	}
481
482	insertAt := impIndex + 1
483	if insertAt == 0 {
484		insertAt = len(impDecl.Specs)
485	}
486	impDecl.Specs = append(impDecl.Specs, nil)
487	copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
488	impDecl.Specs[insertAt] = newImport
489	if insertAt > 0 {
490		// Assign same position as the previous import,
491		// so that the sorter sees it as being in the same block.
492		prev := impDecl.Specs[insertAt-1]
493		newImport.Path.ValuePos = prev.Pos()
494		newImport.EndPos = prev.Pos()
495	}
496
497	f.Imports = append(f.Imports, newImport)
498	return true
499}
500
501// deleteImport deletes the import path from the file f, if present.
502func deleteImport(f *ast.File, path string) (deleted bool) {
503	oldImport := importSpec(f, path)
504
505	// Find the import node that imports path, if any.
506	for i, decl := range f.Decls {
507		gen, ok := decl.(*ast.GenDecl)
508		if !ok || gen.Tok != token.IMPORT {
509			continue
510		}
511		for j, spec := range gen.Specs {
512			impspec := spec.(*ast.ImportSpec)
513			if oldImport != impspec {
514				continue
515			}
516
517			// We found an import spec that imports path.
518			// Delete it.
519			deleted = true
520			copy(gen.Specs[j:], gen.Specs[j+1:])
521			gen.Specs = gen.Specs[:len(gen.Specs)-1]
522
523			// If this was the last import spec in this decl,
524			// delete the decl, too.
525			if len(gen.Specs) == 0 {
526				copy(f.Decls[i:], f.Decls[i+1:])
527				f.Decls = f.Decls[:len(f.Decls)-1]
528			} else if len(gen.Specs) == 1 {
529				gen.Lparen = token.NoPos // drop parens
530			}
531			if j > 0 {
532				// We deleted an entry but now there will be
533				// a blank line-sized hole where the import was.
534				// Close the hole by making the previous
535				// import appear to "end" where this one did.
536				gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
537			}
538			break
539		}
540	}
541
542	// Delete it from f.Imports.
543	for i, imp := range f.Imports {
544		if imp == oldImport {
545			copy(f.Imports[i:], f.Imports[i+1:])
546			f.Imports = f.Imports[:len(f.Imports)-1]
547			break
548		}
549	}
550
551	return
552}
553
554// rewriteImport rewrites any import of path oldPath to path newPath.
555func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
556	for _, imp := range f.Imports {
557		if importPath(imp) == oldPath {
558			rewrote = true
559			// record old End, because the default is to compute
560			// it using the length of imp.Path.Value.
561			imp.EndPos = imp.End()
562			imp.Path.Value = strconv.Quote(newPath)
563		}
564	}
565	return
566}
567