1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"bytes"
9	"flag"
10	"fmt"
11	"go/ast"
12	"go/format"
13	"go/parser"
14	"go/scanner"
15	"go/token"
16	"go/version"
17	"internal/diff"
18	"io"
19	"io/fs"
20	"os"
21	"path/filepath"
22	"sort"
23	"strings"
24
25	"cmd/internal/telemetry/counter"
26)
27
// Process-wide state: every file is parsed and printed against fset, and
// report raises exitCode to 2 whenever an error is reported.
var (
	fset     = token.NewFileSet()
	exitCode = 0
)

// Command-line flags. -r limits which fixes are considered; -force names
// fixes that should run even though they are marked disabled.
var (
	allowedRewrites = flag.String("r", "", "restrict the rewrites to this comma-separated list")
	forceRewrites   = flag.String("force", "", "force these fixes to run even if the code looks updated")
	doDiff          = flag.Bool("diff", false, "display diffs instead of rewriting files")
	goVersion       = flag.String("go", "", "go language version for files")
)

// allowed and force hold the parsed -r and -force lists, filled in by main.
// A nil allowed map permits every fix; a nil force map forces nothing.
var allowed, force map[string]bool

// debug, when set, makes processFile dump the source that failed to
// re-parse after a fix and exit immediately instead of returning the error.
// Flip it only while debugging fix failures.
const debug = false
48
49func usage() {
50	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
51	flag.PrintDefaults()
52	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
53	sort.Sort(byName(fixes))
54	for _, f := range fixes {
55		if f.disabled {
56			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
57		} else {
58			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
59		}
60		desc := strings.TrimSpace(f.desc)
61		desc = strings.ReplaceAll(desc, "\n", "\n\t")
62		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
63	}
64	os.Exit(2)
65}
66
67func main() {
68	counter.Open()
69	flag.Usage = usage
70	flag.Parse()
71	counter.Inc("fix/invocations")
72	counter.CountFlags("fix/flag:", *flag.CommandLine)
73
74	if !version.IsValid(*goVersion) {
75		report(fmt.Errorf("invalid -go=%s", *goVersion))
76		os.Exit(exitCode)
77	}
78
79	sort.Sort(byDate(fixes))
80
81	if *allowedRewrites != "" {
82		allowed = make(map[string]bool)
83		for _, f := range strings.Split(*allowedRewrites, ",") {
84			allowed[f] = true
85		}
86	}
87
88	if *forceRewrites != "" {
89		force = make(map[string]bool)
90		for _, f := range strings.Split(*forceRewrites, ",") {
91			force[f] = true
92		}
93	}
94
95	if flag.NArg() == 0 {
96		if err := processFile("standard input", true); err != nil {
97			report(err)
98		}
99		os.Exit(exitCode)
100	}
101
102	for i := 0; i < flag.NArg(); i++ {
103		path := flag.Arg(i)
104		switch dir, err := os.Stat(path); {
105		case err != nil:
106			report(err)
107		case dir.IsDir():
108			walkDir(path)
109		default:
110			if err := processFile(path, false); err != nil {
111				report(err)
112			}
113		}
114	}
115
116	os.Exit(exitCode)
117}
118
// parserMode is used for every parse; comments must be kept so that
// re-printing a fixed file does not drop them.
const parserMode parser.Mode = parser.ParseComments
120
121func gofmtFile(f *ast.File) ([]byte, error) {
122	var buf bytes.Buffer
123	if err := format.Node(&buf, fset, f); err != nil {
124		return nil, err
125	}
126	return buf.Bytes(), nil
127}
128
// processFile reads one Go source file (or standard input when useStdin is
// set), canonicalizes it with gofmt, and runs every enabled fix over the
// parsed AST. filename is used for parsing, diagnostics and diff headers.
//
// If at least one fix changed the AST, the result is printed once more and:
//   - with -diff, a diff against the original source is written to stdout;
//   - when reading standard input, the fixed source is written to stdout;
//   - otherwise the file is rewritten in place.
//
// It returns the first error from opening, reading, parsing, printing, or
// rewriting the file. Errors from writing to stdout are ignored.
//
// NOTE(review): when no fix applies, nothing is written at all — not even
// on standard input, and not even if the gofmt pass alone changed the
// source (the "fmt" pseudo-fix never sets fixed). Confirm this is intended.
func processFile(filename string, useStdin bool) error {
	var f *os.File
	var err error
	// fixlog accumulates " name" entries (each with a leading space) for
	// the summary line printed once any fix applies.
	var fixlog strings.Builder

	if useStdin {
		f = os.Stdin
	} else {
		f, err = os.Open(filename)
		if err != nil {
			return err
		}
		defer f.Close()
	}

	src, err := io.ReadAll(f)
	if err != nil {
		return err
	}

	file, err := parser.ParseFile(fset, filename, src, parserMode)
	if err != nil {
		return err
	}

	// Make sure file is in canonical format.
	// This "fmt" pseudo-fix cannot be disabled.
	newSrc, err := gofmtFile(file)
	if err != nil {
		return err
	}
	if !bytes.Equal(newSrc, src) {
		// Re-parse so the fixes below see the AST of the canonical text.
		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
		if err != nil {
			return err
		}
		file = newFile
		fmt.Fprintf(&fixlog, " fmt")
	}

	// Apply all fixes to file.
	newFile := file
	fixed := false
	for _, fix := range fixes {
		// -r given: skip fixes it does not name (nil allowed permits all).
		if allowed != nil && !allowed[fix.name] {
			continue
		}
		// Disabled fixes run only when named in -force.
		if fix.disabled && !force[fix.name] {
			continue
		}
		if fix.f(newFile) {
			fixed = true
			fmt.Fprintf(&fixlog, " %s", fix.name)

			// AST changed.
			// Print and parse, to update any missing scoping
			// or position information for subsequent fixers.
			// (newSrc/err here shadow the outer ones; the outer newSrc is
			// recomputed after the loop.)
			newSrc, err := gofmtFile(newFile)
			if err != nil {
				return err
			}
			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
			if err != nil {
				if debug {
					fmt.Printf("%s", newSrc)
					report(err)
					os.Exit(exitCode)
				}
				return err
			}
		}
	}
	if !fixed {
		return nil
	}
	// [1:] drops the leading space of the first fixlog entry.
	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])

	// Print AST.  We did that after each fix, so this appears
	// redundant, but it is necessary to generate gofmt-compatible
	// source code in a few cases. The official gofmt style is the
	// output of the printer run on a standard AST generated by the parser,
	// but the source we generated inside the loop above is the
	// output of the printer run on a mangled AST generated by a fixer.
	newSrc, err = gofmtFile(newFile)
	if err != nil {
		return err
	}

	if *doDiff {
		os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
		return nil
	}

	if useStdin {
		os.Stdout.Write(newSrc)
		return nil
	}

	// The file was opened above, so os.WriteFile truncates it and keeps its
	// permissions; the 0 perm would only matter if it had vanished meanwhile.
	return os.WriteFile(f.Name(), newSrc, 0)
}
229
230func gofmt(n any) string {
231	var gofmtBuf strings.Builder
232	if err := format.Node(&gofmtBuf, fset, n); err != nil {
233		return "<" + err.Error() + ">"
234	}
235	return gofmtBuf.String()
236}
237
238func report(err error) {
239	scanner.PrintError(os.Stderr, err)
240	exitCode = 2
241}
242
243func walkDir(path string) {
244	filepath.WalkDir(path, visitFile)
245}
246
247func visitFile(path string, f fs.DirEntry, err error) error {
248	if err == nil && isGoFile(f) {
249		err = processFile(path, false)
250	}
251	if err != nil {
252		report(err)
253	}
254	return nil
255}
256
// isGoFile reports whether the directory entry f should be processed: it
// must not be a directory, its name must end in ".go", and hidden entries
// (names starting with ".") are skipped.
func isGoFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return strings.HasSuffix(name, ".go") && !strings.HasPrefix(name, ".")
}
262