xref: /aosp_15_r20/external/bazelbuild-rules_go/go/tools/builders/generate_test_main.go (revision 9bb1b549b6a84214c53be0924760be030e66b93a)
1/* Copyright 2016 The Bazel Authors. All rights reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7   http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14*/
15
16// Go testing support for Bazel.
17//
18// A Go test comprises three packages:
19//
20// 1. An internal test package, compiled from the sources of the library being
21//    tested and any _test.go files with the same package name.
22// 2. An external test package, compiled from _test.go files with a package
23//    name ending with "_test".
24// 3. A generated main package that imports both packages and initializes the
25//    test framework with a list of tests, benchmarks, examples, and fuzz
26//    targets read from source files.
27//
28// This action generates the source code for (3). The equivalent code for
29// 'go test' is in $GOROOT/src/cmd/go/internal/load/test.go.
30
31package main
32
33import (
34	"flag"
35	"fmt"
36	"go/ast"
37	"go/build"
38	"go/doc"
39	"go/parser"
40	"go/token"
41	"os"
42	"sort"
43	"strings"
44	"text/template"
45)
46
47type Import struct {
48	Name string
49	Path string
50}
51
52type TestCase struct {
53	Package string
54	Name    string
55}
56
57type Example struct {
58	Package   string
59	Name      string
60	Output    string
61	Unordered bool
62}
63
64// Cases holds template data.
65type Cases struct {
66	Imports     []*Import
67	Tests       []TestCase
68	Benchmarks  []TestCase
69	FuzzTargets []TestCase
70	Examples    []Example
71	TestMain    string
72	CoverMode   string
73	CoverFormat string
74	Pkgname     string
75}
76
77// Version returns whether v is a supported Go version (like "go1.18").
78func (c *Cases) Version(v string) bool {
79	for _, r := range build.Default.ReleaseTags {
80		if v == r {
81			return true
82		}
83	}
84	return false
85}
86
87const testMainTpl = `
88package main
89
90// This package must be initialized before packages being tested.
91// NOTE: this relies on the order of package initialization, which is the spec
92// is somewhat unclear about-- it only clearly guarantees that imported packages
93// are initialized before their importers, though in practice (and implied) it
94// also respects declaration order, which we're relying on here.
95import "github.com/bazelbuild/rules_go/go/tools/bzltestutil"
96
97import (
98	"flag"
99	"log"
100	"os"
101	"os/exec"
102{{if .TestMain}}
103	"reflect"
104{{end}}
105	"strconv"
106	"testing"
107	"testing/internal/testdeps"
108
109{{if ne .CoverMode ""}}
110	"github.com/bazelbuild/rules_go/go/tools/coverdata"
111{{end}}
112
113{{range $p := .Imports}}
114	{{$p.Name}} "{{$p.Path}}"
115{{end}}
116)
117
118var allTests = []testing.InternalTest{
119{{range .Tests}}
120	{"{{.Name}}", {{.Package}}.{{.Name}} },
121{{end}}
122}
123
124var benchmarks = []testing.InternalBenchmark{
125{{range .Benchmarks}}
126	{"{{.Name}}", {{.Package}}.{{.Name}} },
127{{end}}
128}
129
130{{if .Version "go1.18"}}
131var fuzzTargets = []testing.InternalFuzzTarget{
132{{range .FuzzTargets}}
133  {"{{.Name}}", {{.Package}}.{{.Name}} },
134{{end}}
135}
136{{end}}
137
138var examples = []testing.InternalExample{
139{{range .Examples}}
140	{Name: "{{.Name}}", F: {{.Package}}.{{.Name}}, Output: {{printf "%q" .Output}}, Unordered: {{.Unordered}} },
141{{end}}
142}
143
144func testsInShard() []testing.InternalTest {
145	totalShards, err := strconv.Atoi(os.Getenv("TEST_TOTAL_SHARDS"))
146	if err != nil || totalShards <= 1 {
147		return allTests
148	}
149	file, err := os.Create(os.Getenv("TEST_SHARD_STATUS_FILE"))
150	if err != nil {
151		log.Fatalf("Failed to touch TEST_SHARD_STATUS_FILE: %v", err)
152	}
153	_ = file.Close()
154	shardIndex, err := strconv.Atoi(os.Getenv("TEST_SHARD_INDEX"))
155	if err != nil || shardIndex < 0 {
156		return allTests
157	}
158	tests := []testing.InternalTest{}
159	for i, t := range allTests {
160		if i % totalShards == shardIndex {
161			tests = append(tests, t)
162		}
163	}
164	return tests
165}
166
167func main() {
168	if bzltestutil.ShouldWrap() {
169		err := bzltestutil.Wrap("{{.Pkgname}}")
170		if xerr, ok := err.(*exec.ExitError); ok {
171			os.Exit(xerr.ExitCode())
172		} else if err != nil {
173			log.Print(err)
174			os.Exit(bzltestutil.TestWrapperAbnormalExit)
175		} else {
176			os.Exit(0)
177		}
178	}
179
180	testDeps :=
181  {{if eq .CoverFormat "lcov"}}
182		bzltestutil.LcovTestDeps{TestDeps: testdeps.TestDeps{}}
183  {{else}}
184		testdeps.TestDeps{}
185  {{end}}
186  {{if .Version "go1.18"}}
187	m := testing.MainStart(testDeps, testsInShard(), benchmarks, fuzzTargets, examples)
188  {{else}}
189	m := testing.MainStart(testDeps, testsInShard(), benchmarks, examples)
190  {{end}}
191
192	if filter := os.Getenv("TESTBRIDGE_TEST_ONLY"); filter != "" {
193		flag.Lookup("test.run").Value.Set(filter)
194	}
195
196	if failfast := os.Getenv("TESTBRIDGE_TEST_RUNNER_FAIL_FAST"); failfast != "" {
197		flag.Lookup("test.failfast").Value.Set("true")
198	}
199{{if eq .CoverFormat "lcov"}}
200	panicOnExit0Flag := flag.Lookup("test.paniconexit0").Value
201	testDeps.OriginalPanicOnExit = panicOnExit0Flag.(flag.Getter).Get().(bool)
202	// Setting this flag provides a way to run hooks right before testing.M.Run() returns.
203	panicOnExit0Flag.Set("true")
204{{end}}
205{{if ne .CoverMode ""}}
206	if len(coverdata.Counters) > 0 {
207		testing.RegisterCover(testing.Cover{
208			Mode: "{{ .CoverMode }}",
209			Counters: coverdata.Counters,
210			Blocks: coverdata.Blocks,
211		})
212
213		if coverageDat, ok := os.LookupEnv("COVERAGE_OUTPUT_FILE"); ok {
214			{{if eq .CoverFormat "lcov"}}
215			flag.Lookup("test.coverprofile").Value.Set(coverageDat+".cover")
216			{{else}}
217			flag.Lookup("test.coverprofile").Value.Set(coverageDat)
218			{{end}}
219		}
220	}
221	{{end}}
222
223	{{if not .TestMain}}
224	res := m.Run()
225	{{else}}
226	{{.TestMain}}(m)
227	{{/* See golang.org/issue/34129 and golang.org/cl/219639 */}}
228	res := int(reflect.ValueOf(m).Elem().FieldByName("exitCode").Int())
229	{{end}}
230	os.Exit(res)
231}
232`
233
234func genTestMain(args []string) error {
235	// Prepare our flags
236	args, _, err := expandParamsFiles(args)
237	if err != nil {
238		return err
239	}
240	imports := multiFlag{}
241	sources := multiFlag{}
242	flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError)
243	goenv := envFlags(flags)
244	out := flags.String("output", "", "output file to write. Defaults to stdout.")
245	coverMode := flags.String("cover_mode", "", "the coverage mode to use")
246	coverFormat := flags.String("cover_format", "", "the coverage report type to generate (go_cover or lcov)")
247	pkgname := flags.String("pkgname", "", "package name of test")
248	flags.Var(&imports, "import", "Packages to import")
249	flags.Var(&sources, "src", "Sources to process for tests")
250	if err := flags.Parse(args); err != nil {
251		return err
252	}
253	if err := goenv.checkFlags(); err != nil {
254		return err
255	}
256	// Process import args
257	importMap := map[string]*Import{}
258	for _, imp := range imports {
259		parts := strings.Split(imp, "=")
260		if len(parts) != 2 {
261			return fmt.Errorf("Invalid import %q specified", imp)
262		}
263		i := &Import{Name: parts[0], Path: parts[1]}
264		importMap[i.Name] = i
265	}
266	// Process source args
267	sourceList := []string{}
268	sourceMap := map[string]string{}
269	for _, s := range sources {
270		parts := strings.Split(s, "=")
271		if len(parts) != 2 {
272			return fmt.Errorf("Invalid source %q specified", s)
273		}
274		sourceList = append(sourceList, parts[1])
275		sourceMap[parts[1]] = parts[0]
276	}
277
278	// filter our input file list
279	filteredSrcs, err := filterAndSplitFiles(sourceList)
280	if err != nil {
281		return err
282	}
283	goSrcs := filteredSrcs.goSrcs
284
285	outFile := os.Stdout
286	if *out != "" {
287		var err error
288		outFile, err = os.Create(*out)
289		if err != nil {
290			return fmt.Errorf("os.Create(%q): %v", *out, err)
291		}
292		defer outFile.Close()
293	}
294
295	cases := Cases{
296		CoverFormat: *coverFormat,
297		CoverMode:   *coverMode,
298		Pkgname:     *pkgname,
299	}
300
301	testFileSet := token.NewFileSet()
302	pkgs := map[string]bool{}
303	for _, f := range goSrcs {
304		parse, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments)
305		if err != nil {
306			return fmt.Errorf("ParseFile(%q): %v", f.filename, err)
307		}
308		pkg := sourceMap[f.filename]
309		if strings.HasSuffix(parse.Name.String(), "_test") {
310			pkg += "_test"
311		}
312		for _, e := range doc.Examples(parse) {
313			if e.Output == "" && !e.EmptyOutput {
314				continue
315			}
316			cases.Examples = append(cases.Examples, Example{
317				Name:      "Example" + e.Name,
318				Package:   pkg,
319				Output:    e.Output,
320				Unordered: e.Unordered,
321			})
322			pkgs[pkg] = true
323		}
324		for _, d := range parse.Decls {
325			fn, ok := d.(*ast.FuncDecl)
326			if !ok {
327				continue
328			}
329			if fn.Recv != nil {
330				continue
331			}
332			if fn.Name.Name == "TestMain" {
333				// TestMain is not, itself, a test
334				pkgs[pkg] = true
335				cases.TestMain = fmt.Sprintf("%s.%s", pkg, fn.Name.Name)
336				continue
337			}
338
339			// Here we check the signature of the Test* function. To
340			// be considered a test:
341
342			// 1. The function should have a single argument.
343			if len(fn.Type.Params.List) != 1 {
344				continue
345			}
346
347			// 2. The function should return nothing.
348			if fn.Type.Results != nil {
349				continue
350			}
351
352			// 3. The only parameter should have a type identified as
353			//    *<something>.T
354			starExpr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
355			if !ok {
356				continue
357			}
358			selExpr, ok := starExpr.X.(*ast.SelectorExpr)
359			if !ok {
360				continue
361			}
362
363			// We do not descriminate on the referenced type of the
364			// parameter being *testing.T. Instead we assert that it
365			// should be *<something>.T. This is because the import
366			// could have been aliased as a different identifier.
367
368			if strings.HasPrefix(fn.Name.Name, "Test") {
369				if selExpr.Sel.Name != "T" {
370					continue
371				}
372				pkgs[pkg] = true
373				cases.Tests = append(cases.Tests, TestCase{
374					Package: pkg,
375					Name:    fn.Name.Name,
376				})
377			}
378			if strings.HasPrefix(fn.Name.Name, "Benchmark") {
379				if selExpr.Sel.Name != "B" {
380					continue
381				}
382				pkgs[pkg] = true
383				cases.Benchmarks = append(cases.Benchmarks, TestCase{
384					Package: pkg,
385					Name:    fn.Name.Name,
386				})
387			}
388			if strings.HasPrefix(fn.Name.Name, "Fuzz") {
389				if selExpr.Sel.Name != "F" {
390					continue
391				}
392				pkgs[pkg] = true
393				cases.FuzzTargets = append(cases.FuzzTargets, TestCase{
394					Package: pkg,
395					Name:    fn.Name.Name,
396				})
397			}
398		}
399	}
400
401	for name := range importMap {
402		// Set the names for all unused imports to "_"
403		if !pkgs[name] {
404			importMap[name].Name = "_"
405		}
406		cases.Imports = append(cases.Imports, importMap[name])
407	}
408	sort.Slice(cases.Imports, func(i, j int) bool {
409		return cases.Imports[i].Name < cases.Imports[j].Name
410	})
411	tpl := template.Must(template.New("source").Parse(testMainTpl))
412	if err := tpl.Execute(outFile, &cases); err != nil {
413		return fmt.Errorf("template.Execute(%v): %v", cases, err)
414	}
415	return nil
416}
417