1// Copyright 2011 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package main 6 7import ( 8 "bytes" 9 "flag" 10 "fmt" 11 "go/ast" 12 "go/format" 13 "go/parser" 14 "go/scanner" 15 "go/token" 16 "go/version" 17 "internal/diff" 18 "io" 19 "io/fs" 20 "os" 21 "path/filepath" 22 "sort" 23 "strings" 24 25 "cmd/internal/telemetry/counter" 26) 27 28var ( 29 fset = token.NewFileSet() 30 exitCode = 0 31) 32 33var allowedRewrites = flag.String("r", "", 34 "restrict the rewrites to this comma-separated list") 35 36var forceRewrites = flag.String("force", "", 37 "force these fixes to run even if the code looks updated") 38 39var allowed, force map[string]bool 40 41var ( 42 doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") 43 goVersion = flag.String("go", "", "go language version for files") 44) 45 46// enable for debugging fix failures 47const debug = false // display incorrectly reformatted source and exit 48 49func usage() { 50 fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n") 51 flag.PrintDefaults() 52 fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") 53 sort.Sort(byName(fixes)) 54 for _, f := range fixes { 55 if f.disabled { 56 fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name) 57 } else { 58 fmt.Fprintf(os.Stderr, "\n%s\n", f.name) 59 } 60 desc := strings.TrimSpace(f.desc) 61 desc = strings.ReplaceAll(desc, "\n", "\n\t") 62 fmt.Fprintf(os.Stderr, "\t%s\n", desc) 63 } 64 os.Exit(2) 65} 66 67func main() { 68 counter.Open() 69 flag.Usage = usage 70 flag.Parse() 71 counter.Inc("fix/invocations") 72 counter.CountFlags("fix/flag:", *flag.CommandLine) 73 74 if !version.IsValid(*goVersion) { 75 report(fmt.Errorf("invalid -go=%s", *goVersion)) 76 os.Exit(exitCode) 77 } 78 79 sort.Sort(byDate(fixes)) 80 81 if *allowedRewrites != "" { 82 allowed = make(map[string]bool) 83 for _, f := range strings.Split(*allowedRewrites, ",") { 84 allowed[f] = true 85 } 86 } 87 88 if *forceRewrites != "" { 89 force = make(map[string]bool) 90 for _, f := range strings.Split(*forceRewrites, ",") { 91 force[f] = true 92 } 93 } 94 95 if flag.NArg() == 0 { 96 if err := processFile("standard input", true); err != nil { 97 report(err) 98 } 99 os.Exit(exitCode) 100 } 101 102 for i := 0; i < flag.NArg(); i++ { 103 path := flag.Arg(i) 104 switch dir, err := os.Stat(path); { 105 case err != nil: 106 report(err) 107 case dir.IsDir(): 108 walkDir(path) 109 default: 110 if err := processFile(path, false); err != nil { 111 report(err) 112 } 113 } 114 } 115 116 os.Exit(exitCode) 117} 118 119const parserMode = parser.ParseComments 120 121func gofmtFile(f *ast.File) ([]byte, error) { 122 var buf bytes.Buffer 123 if err := format.Node(&buf, fset, f); err != nil { 124 return nil, err 125 } 126 return buf.Bytes(), nil 127} 128 129func processFile(filename string, useStdin bool) error { 130 var f *os.File 131 var err error 132 var fixlog strings.Builder 133 134 if useStdin { 135 f = os.Stdin 136 } else { 137 f, err = os.Open(filename) 138 if err != nil { 139 return err 140 } 141 defer f.Close() 142 } 143 144 src, err := io.ReadAll(f) 145 if err != nil { 146 return err 147 } 148 149 file, err := parser.ParseFile(fset, filename, src, parserMode) 150 if err != nil { 151 return err 152 } 153 154 // Make sure file is in canonical format. 155 // This "fmt" pseudo-fix cannot be disabled. 156 newSrc, err := gofmtFile(file) 157 if err != nil { 158 return err 159 } 160 if !bytes.Equal(newSrc, src) { 161 newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode) 162 if err != nil { 163 return err 164 } 165 file = newFile 166 fmt.Fprintf(&fixlog, " fmt") 167 } 168 169 // Apply all fixes to file. 170 newFile := file 171 fixed := false 172 for _, fix := range fixes { 173 if allowed != nil && !allowed[fix.name] { 174 continue 175 } 176 if fix.disabled && !force[fix.name] { 177 continue 178 } 179 if fix.f(newFile) { 180 fixed = true 181 fmt.Fprintf(&fixlog, " %s", fix.name) 182 183 // AST changed. 184 // Print and parse, to update any missing scoping 185 // or position information for subsequent fixers. 186 newSrc, err := gofmtFile(newFile) 187 if err != nil { 188 return err 189 } 190 newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode) 191 if err != nil { 192 if debug { 193 fmt.Printf("%s", newSrc) 194 report(err) 195 os.Exit(exitCode) 196 } 197 return err 198 } 199 } 200 } 201 if !fixed { 202 return nil 203 } 204 fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:]) 205 206 // Print AST. We did that after each fix, so this appears 207 // redundant, but it is necessary to generate gofmt-compatible 208 // source code in a few cases. The official gofmt style is the 209 // output of the printer run on a standard AST generated by the parser, 210 // but the source we generated inside the loop above is the 211 // output of the printer run on a mangled AST generated by a fixer. 212 newSrc, err = gofmtFile(newFile) 213 if err != nil { 214 return err 215 } 216 217 if *doDiff { 218 os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc)) 219 return nil 220 } 221 222 if useStdin { 223 os.Stdout.Write(newSrc) 224 return nil 225 } 226 227 return os.WriteFile(f.Name(), newSrc, 0) 228} 229 230func gofmt(n any) string { 231 var gofmtBuf strings.Builder 232 if err := format.Node(&gofmtBuf, fset, n); err != nil { 233 return "<" + err.Error() + ">" 234 } 235 return gofmtBuf.String() 236} 237 238func report(err error) { 239 scanner.PrintError(os.Stderr, err) 240 exitCode = 2 241} 242 243func walkDir(path string) { 244 filepath.WalkDir(path, visitFile) 245} 246 247func visitFile(path string, f fs.DirEntry, err error) error { 248 if err == nil && isGoFile(f) { 249 err = processFile(path, false) 250 } 251 if err != nil { 252 report(err) 253 } 254 return nil 255} 256 257func isGoFile(f fs.DirEntry) bool { 258 // ignore non-Go files 259 name := f.Name() 260 return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") 261} 262