/* Copyright 2016 The Bazel Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Go testing support for Bazel.
//
// A Go test comprises three packages:
//
// 1. An internal test package, compiled from the sources of the library being
//    tested and any _test.go files with the same package name.
// 2. An external test package, compiled from _test.go files with a package
//    name ending with "_test".
// 3. A generated main package that imports both packages and initializes the
//    test framework with a list of tests, benchmarks, examples, and fuzz
//    targets read from source files.
//
// This action generates the source code for (3). The equivalent code for
// 'go test' is in $GOROOT/src/cmd/go/internal/load/test.go.

package main

import (
	"flag"
	"fmt"
	"go/ast"
	"go/build"
	"go/doc"
	"go/parser"
	"go/token"
	"os"
	"sort"
	"strings"
	"text/template"
)

// Import is one package imported by the generated main package.
//
// Name is the identifier the package is imported under in the generated file
// (it comes from the left-hand side of an -import name=path flag and is
// replaced with "_" when nothing in the generated file refers to it); Path is
// the import path.
type Import struct {
	Name string
	Path string
}

// TestCase identifies one test, benchmark, or fuzz target function. The
// generated file refers to it as Package.Name, where Package is the import
// name of the declaring package ("_test" is appended for external test
// packages) and Name is the function name.
type TestCase struct {
	Package string
	Name    string
}

// Example identifies one example function to run. Package and Name have the
// same meaning as in TestCase; Output and Unordered are copied from the
// go/doc Example describing the function's "Output:" comment.
type Example struct {
	Package   string
	Name      string
	Output    string
	Unordered bool
}

// Cases holds template data.
65type Cases struct { 66 Imports []*Import 67 Tests []TestCase 68 Benchmarks []TestCase 69 FuzzTargets []TestCase 70 Examples []Example 71 TestMain string 72 CoverMode string 73 CoverFormat string 74 Pkgname string 75} 76 77// Version returns whether v is a supported Go version (like "go1.18"). 78func (c *Cases) Version(v string) bool { 79 for _, r := range build.Default.ReleaseTags { 80 if v == r { 81 return true 82 } 83 } 84 return false 85} 86 87const testMainTpl = ` 88package main 89 90// This package must be initialized before packages being tested. 91// NOTE: this relies on the order of package initialization, which is the spec 92// is somewhat unclear about-- it only clearly guarantees that imported packages 93// are initialized before their importers, though in practice (and implied) it 94// also respects declaration order, which we're relying on here. 95import "github.com/bazelbuild/rules_go/go/tools/bzltestutil" 96 97import ( 98 "flag" 99 "log" 100 "os" 101 "os/exec" 102{{if .TestMain}} 103 "reflect" 104{{end}} 105 "strconv" 106 "testing" 107 "testing/internal/testdeps" 108 109{{if ne .CoverMode ""}} 110 "github.com/bazelbuild/rules_go/go/tools/coverdata" 111{{end}} 112 113{{range $p := .Imports}} 114 {{$p.Name}} "{{$p.Path}}" 115{{end}} 116) 117 118var allTests = []testing.InternalTest{ 119{{range .Tests}} 120 {"{{.Name}}", {{.Package}}.{{.Name}} }, 121{{end}} 122} 123 124var benchmarks = []testing.InternalBenchmark{ 125{{range .Benchmarks}} 126 {"{{.Name}}", {{.Package}}.{{.Name}} }, 127{{end}} 128} 129 130{{if .Version "go1.18"}} 131var fuzzTargets = []testing.InternalFuzzTarget{ 132{{range .FuzzTargets}} 133 {"{{.Name}}", {{.Package}}.{{.Name}} }, 134{{end}} 135} 136{{end}} 137 138var examples = []testing.InternalExample{ 139{{range .Examples}} 140 {Name: "{{.Name}}", F: {{.Package}}.{{.Name}}, Output: {{printf "%q" .Output}}, Unordered: {{.Unordered}} }, 141{{end}} 142} 143 144func testsInShard() []testing.InternalTest { 145 totalShards, err := 
strconv.Atoi(os.Getenv("TEST_TOTAL_SHARDS")) 146 if err != nil || totalShards <= 1 { 147 return allTests 148 } 149 file, err := os.Create(os.Getenv("TEST_SHARD_STATUS_FILE")) 150 if err != nil { 151 log.Fatalf("Failed to touch TEST_SHARD_STATUS_FILE: %v", err) 152 } 153 _ = file.Close() 154 shardIndex, err := strconv.Atoi(os.Getenv("TEST_SHARD_INDEX")) 155 if err != nil || shardIndex < 0 { 156 return allTests 157 } 158 tests := []testing.InternalTest{} 159 for i, t := range allTests { 160 if i % totalShards == shardIndex { 161 tests = append(tests, t) 162 } 163 } 164 return tests 165} 166 167func main() { 168 if bzltestutil.ShouldWrap() { 169 err := bzltestutil.Wrap("{{.Pkgname}}") 170 if xerr, ok := err.(*exec.ExitError); ok { 171 os.Exit(xerr.ExitCode()) 172 } else if err != nil { 173 log.Print(err) 174 os.Exit(bzltestutil.TestWrapperAbnormalExit) 175 } else { 176 os.Exit(0) 177 } 178 } 179 180 testDeps := 181 {{if eq .CoverFormat "lcov"}} 182 bzltestutil.LcovTestDeps{TestDeps: testdeps.TestDeps{}} 183 {{else}} 184 testdeps.TestDeps{} 185 {{end}} 186 {{if .Version "go1.18"}} 187 m := testing.MainStart(testDeps, testsInShard(), benchmarks, fuzzTargets, examples) 188 {{else}} 189 m := testing.MainStart(testDeps, testsInShard(), benchmarks, examples) 190 {{end}} 191 192 if filter := os.Getenv("TESTBRIDGE_TEST_ONLY"); filter != "" { 193 flag.Lookup("test.run").Value.Set(filter) 194 } 195 196 if failfast := os.Getenv("TESTBRIDGE_TEST_RUNNER_FAIL_FAST"); failfast != "" { 197 flag.Lookup("test.failfast").Value.Set("true") 198 } 199{{if eq .CoverFormat "lcov"}} 200 panicOnExit0Flag := flag.Lookup("test.paniconexit0").Value 201 testDeps.OriginalPanicOnExit = panicOnExit0Flag.(flag.Getter).Get().(bool) 202 // Setting this flag provides a way to run hooks right before testing.M.Run() returns. 
203 panicOnExit0Flag.Set("true") 204{{end}} 205{{if ne .CoverMode ""}} 206 if len(coverdata.Counters) > 0 { 207 testing.RegisterCover(testing.Cover{ 208 Mode: "{{ .CoverMode }}", 209 Counters: coverdata.Counters, 210 Blocks: coverdata.Blocks, 211 }) 212 213 if coverageDat, ok := os.LookupEnv("COVERAGE_OUTPUT_FILE"); ok { 214 {{if eq .CoverFormat "lcov"}} 215 flag.Lookup("test.coverprofile").Value.Set(coverageDat+".cover") 216 {{else}} 217 flag.Lookup("test.coverprofile").Value.Set(coverageDat) 218 {{end}} 219 } 220 } 221 {{end}} 222 223 {{if not .TestMain}} 224 res := m.Run() 225 {{else}} 226 {{.TestMain}}(m) 227 {{/* See golang.org/issue/34129 and golang.org/cl/219639 */}} 228 res := int(reflect.ValueOf(m).Elem().FieldByName("exitCode").Int()) 229 {{end}} 230 os.Exit(res) 231} 232` 233 234func genTestMain(args []string) error { 235 // Prepare our flags 236 args, _, err := expandParamsFiles(args) 237 if err != nil { 238 return err 239 } 240 imports := multiFlag{} 241 sources := multiFlag{} 242 flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError) 243 goenv := envFlags(flags) 244 out := flags.String("output", "", "output file to write. 
Defaults to stdout.") 245 coverMode := flags.String("cover_mode", "", "the coverage mode to use") 246 coverFormat := flags.String("cover_format", "", "the coverage report type to generate (go_cover or lcov)") 247 pkgname := flags.String("pkgname", "", "package name of test") 248 flags.Var(&imports, "import", "Packages to import") 249 flags.Var(&sources, "src", "Sources to process for tests") 250 if err := flags.Parse(args); err != nil { 251 return err 252 } 253 if err := goenv.checkFlags(); err != nil { 254 return err 255 } 256 // Process import args 257 importMap := map[string]*Import{} 258 for _, imp := range imports { 259 parts := strings.Split(imp, "=") 260 if len(parts) != 2 { 261 return fmt.Errorf("Invalid import %q specified", imp) 262 } 263 i := &Import{Name: parts[0], Path: parts[1]} 264 importMap[i.Name] = i 265 } 266 // Process source args 267 sourceList := []string{} 268 sourceMap := map[string]string{} 269 for _, s := range sources { 270 parts := strings.Split(s, "=") 271 if len(parts) != 2 { 272 return fmt.Errorf("Invalid source %q specified", s) 273 } 274 sourceList = append(sourceList, parts[1]) 275 sourceMap[parts[1]] = parts[0] 276 } 277 278 // filter our input file list 279 filteredSrcs, err := filterAndSplitFiles(sourceList) 280 if err != nil { 281 return err 282 } 283 goSrcs := filteredSrcs.goSrcs 284 285 outFile := os.Stdout 286 if *out != "" { 287 var err error 288 outFile, err = os.Create(*out) 289 if err != nil { 290 return fmt.Errorf("os.Create(%q): %v", *out, err) 291 } 292 defer outFile.Close() 293 } 294 295 cases := Cases{ 296 CoverFormat: *coverFormat, 297 CoverMode: *coverMode, 298 Pkgname: *pkgname, 299 } 300 301 testFileSet := token.NewFileSet() 302 pkgs := map[string]bool{} 303 for _, f := range goSrcs { 304 parse, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments) 305 if err != nil { 306 return fmt.Errorf("ParseFile(%q): %v", f.filename, err) 307 } 308 pkg := sourceMap[f.filename] 309 if 
strings.HasSuffix(parse.Name.String(), "_test") { 310 pkg += "_test" 311 } 312 for _, e := range doc.Examples(parse) { 313 if e.Output == "" && !e.EmptyOutput { 314 continue 315 } 316 cases.Examples = append(cases.Examples, Example{ 317 Name: "Example" + e.Name, 318 Package: pkg, 319 Output: e.Output, 320 Unordered: e.Unordered, 321 }) 322 pkgs[pkg] = true 323 } 324 for _, d := range parse.Decls { 325 fn, ok := d.(*ast.FuncDecl) 326 if !ok { 327 continue 328 } 329 if fn.Recv != nil { 330 continue 331 } 332 if fn.Name.Name == "TestMain" { 333 // TestMain is not, itself, a test 334 pkgs[pkg] = true 335 cases.TestMain = fmt.Sprintf("%s.%s", pkg, fn.Name.Name) 336 continue 337 } 338 339 // Here we check the signature of the Test* function. To 340 // be considered a test: 341 342 // 1. The function should have a single argument. 343 if len(fn.Type.Params.List) != 1 { 344 continue 345 } 346 347 // 2. The function should return nothing. 348 if fn.Type.Results != nil { 349 continue 350 } 351 352 // 3. The only parameter should have a type identified as 353 // *<something>.T 354 starExpr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr) 355 if !ok { 356 continue 357 } 358 selExpr, ok := starExpr.X.(*ast.SelectorExpr) 359 if !ok { 360 continue 361 } 362 363 // We do not descriminate on the referenced type of the 364 // parameter being *testing.T. Instead we assert that it 365 // should be *<something>.T. This is because the import 366 // could have been aliased as a different identifier. 
367 368 if strings.HasPrefix(fn.Name.Name, "Test") { 369 if selExpr.Sel.Name != "T" { 370 continue 371 } 372 pkgs[pkg] = true 373 cases.Tests = append(cases.Tests, TestCase{ 374 Package: pkg, 375 Name: fn.Name.Name, 376 }) 377 } 378 if strings.HasPrefix(fn.Name.Name, "Benchmark") { 379 if selExpr.Sel.Name != "B" { 380 continue 381 } 382 pkgs[pkg] = true 383 cases.Benchmarks = append(cases.Benchmarks, TestCase{ 384 Package: pkg, 385 Name: fn.Name.Name, 386 }) 387 } 388 if strings.HasPrefix(fn.Name.Name, "Fuzz") { 389 if selExpr.Sel.Name != "F" { 390 continue 391 } 392 pkgs[pkg] = true 393 cases.FuzzTargets = append(cases.FuzzTargets, TestCase{ 394 Package: pkg, 395 Name: fn.Name.Name, 396 }) 397 } 398 } 399 } 400 401 for name := range importMap { 402 // Set the names for all unused imports to "_" 403 if !pkgs[name] { 404 importMap[name].Name = "_" 405 } 406 cases.Imports = append(cases.Imports, importMap[name]) 407 } 408 sort.Slice(cases.Imports, func(i, j int) bool { 409 return cases.Imports[i].Name < cases.Imports[j].Name 410 }) 411 tpl := template.Must(template.New("source").Parse(testMainTpl)) 412 if err := tpl.Execute(outFile, &cases); err != nil { 413 return fmt.Errorf("template.Execute(%v): %v", cases, err) 414 } 415 return nil 416} 417