1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Parse input AST and prepare Prog structure.
6
7package main
8
9import (
10	"fmt"
11	"go/ast"
12	"go/format"
13	"go/parser"
14	"go/scanner"
15	"go/token"
16	"os"
17	"strings"
18)
19
20func parse(name string, src []byte, flags parser.Mode) *ast.File {
21	ast1, err := parser.ParseFile(fset, name, src, flags)
22	if err != nil {
23		if list, ok := err.(scanner.ErrorList); ok {
24			// If err is a scanner.ErrorList, its String will print just
25			// the first error and then (+n more errors).
26			// Instead, turn it into a new Error that will return
27			// details for all the errors.
28			for _, e := range list {
29				fmt.Fprintln(os.Stderr, e)
30			}
31			os.Exit(2)
32		}
33		fatalf("parsing %s: %s", name, err)
34	}
35	return ast1
36}
37
38func sourceLine(n ast.Node) int {
39	return fset.Position(n.Pos()).Line
40}
41
// ParseGo populates f with information learned from the Go source code
// which was read from the named file. It gathers the C preamble
// attached to the import "C" comment, a list of references to C.xxx,
// a list of exported functions, and the actual AST, to be rewritten and
// printed.
//
// abspath is used as the file name for parsing and in the #line directive
// written ahead of the preamble in f.Preamble. Problems in the source are
// reported with error_: a renamed import "C", a missing import "C", or a
// method declared on a receiver whose type is a qualified name such as C.T
// or *C.T. Syntax errors are handled (fatally) inside parse.
func (f *File) ParseGo(abspath string, src []byte) {
	// Two different parses: once with comments, once without.
	// The printer is not good enough at printing comments in the
	// right place when we start editing the AST behind its back,
	// so we use ast1 to look for the doc comments on import "C"
	// and on exported functions, and we use ast2 for translating
	// and reprinting.
	// In cgo mode, ast2 is not reprinted; instead, edits are recorded
	// against the original source text (f.Edit). In godefs mode we
	// modify and print ast2.
	ast1 := parse(abspath, src, parser.SkipObjectResolution|parser.ParseComments)
	ast2 := parse(abspath, src, parser.SkipObjectResolution)

	f.Package = ast1.Name.Name
	// Filled in by saveRef during the walk over ast2 below.
	f.Name = make(map[string]*Name)
	f.NamePos = make(map[*Name]token.Pos)

	// In ast1, find the import "C" line and get any extra C preamble.
	sawC := false
	for _, decl := range ast1.Decls {
		switch decl := decl.(type) {
		case *ast.GenDecl:
			for _, spec := range decl.Specs {
				s, ok := spec.(*ast.ImportSpec)
				if !ok || s.Path.Value != `"C"` {
					continue
				}
				sawC = true
				if s.Name != nil {
					error_(s.Path.Pos(), `cannot rename import "C"`)
				}
				// If the spec has no doc comment of its own and is the only
				// spec in its declaration, fall back to the declaration's
				// doc comment; that is where the parser attaches the comment
				// for a plain `import "C"`.
				cg := s.Doc
				if cg == nil && len(decl.Specs) == 1 {
					cg = decl.Doc
				}
				if cg != nil {
					if strings.ContainsAny(abspath, "\r\n") {
						// This should have been checked when the file path was first resolved,
						// but we double check here just to be sure.
						fatalf("internal error: ParseGo: abspath contains unexpected newline character: %q", abspath)
					}
					// Bracket the preamble with #line directives: first so C
					// diagnostics point at the comment's line in the Go file,
					// then switch to a synthetic file name for what follows.
					f.Preamble += fmt.Sprintf("#line %d %q\n", sourceLine(cg), abspath)
					f.Preamble += commentText(cg) + "\n"
					f.Preamble += "#line 1 \"cgo-generated-wrapper\"\n"
				}
			}

		case *ast.FuncDecl:
			// Also, reject attempts to declare methods on C.T or *C.T.
			// (The generated code would otherwise accept this
			// invalid input; see issue #57926.)
			if decl.Recv != nil && len(decl.Recv.List) > 0 {
				recvType := decl.Recv.List[0].Type
				if recvType != nil {
					t := recvType
					if star, ok := unparen(t).(*ast.StarExpr); ok {
						t = star.X
					}
					if sel, ok := unparen(t).(*ast.SelectorExpr); ok {
						var buf strings.Builder
						format.Node(&buf, fset, recvType)
						error_(sel.Pos(), `cannot define new methods on non-local type %s`, &buf)
					}
				}
			}
		}

	}
	if !sawC {
		error_(ast1.Package, `cannot find import "C"`)
	}

	// In ast2, strip the import "C" line.
	if *godefs {
		// godefs mode: compact ast2.Decls in place, removing every
		// import "C" spec and dropping any GenDecl left with no specs.
		w := 0
		for _, decl := range ast2.Decls {
			d, ok := decl.(*ast.GenDecl)
			if !ok {
				ast2.Decls[w] = decl
				w++
				continue
			}
			ws := 0
			for _, spec := range d.Specs {
				s, ok := spec.(*ast.ImportSpec)
				if !ok || s.Path.Value != `"C"` {
					d.Specs[ws] = spec
					ws++
				}
			}
			if ws == 0 {
				continue
			}
			d.Specs = d.Specs[0:ws]
			ast2.Decls[w] = d
			w++
		}
		ast2.Decls = ast2.Decls[0:w]
	} else {
		for _, decl := range ast2.Decls {
			d, ok := decl.(*ast.GenDecl)
			if !ok {
				continue
			}
			for _, spec := range d.Specs {
				if s, ok := spec.(*ast.ImportSpec); ok && s.Path.Value == `"C"` {
					// Replace "C" with _ "unsafe", to keep program valid.
					// (Deleting import statement or clause is not safe if it is followed
					// in the source by an explicit semicolon.)
					f.Edit.Replace(f.offset(s.Path.Pos()), f.offset(s.Path.End()), `_ "unsafe"`)
				}
			}
		}
	}

	// Accumulate pointers to uses of C.x.
	if f.Ref == nil {
		f.Ref = make([]*Ref, 0, 8)
	}
	f.walk(ast2, ctxProg, (*File).validateIdents)
	f.walk(ast2, ctxProg, (*File).saveExprs)

	// Accumulate exported functions.
	// The comments are only on ast1 but we need to
	// save the function bodies from ast2.
	// The first walk fills in ExpFunc, and the
	// second walk changes the entries to
	// refer to ast2 instead.
	f.walk(ast1, ctxProg, (*File).saveExport)
	f.walk(ast2, ctxProg, (*File).saveExport2)

	f.Comments = ast1.Comments
	f.AST = ast2
}
180
// commentText returns the text of comment group g with the comment markers
// removed. It is like ast.CommentGroup's Text method, except that it does no
// other cleanup and, in particular, preserves leading blank lines, so that
// line numbers in the result line up with the source. Each //-style comment
// contributes its text followed by a newline. Each /*-style comment
// contributes exactly the text between its delimiters.
func commentText(g *ast.CommentGroup) string {
	var b strings.Builder
	for _, com := range g.List {
		text := com.Text
		// The parser hands us the exact comment text, markers included.
		switch text[1] {
		case '/':
			// "//" comments carry no trailing newline; add one back.
			b.WriteString(text[2:])
			b.WriteByte('\n')
		case '*':
			b.WriteString(text[2 : len(text)-2])
		default:
			b.WriteString(text)
		}
	}
	return b.String()
}
201
202func (f *File) validateIdents(x interface{}, context astContext) {
203	if x, ok := x.(*ast.Ident); ok {
204		if f.isMangledName(x.Name) {
205			error_(x.Pos(), "identifier %q may conflict with identifiers generated by cgo", x.Name)
206		}
207	}
208}
209
210// Save various references we are going to need later.
211func (f *File) saveExprs(x interface{}, context astContext) {
212	switch x := x.(type) {
213	case *ast.Expr:
214		switch (*x).(type) {
215		case *ast.SelectorExpr:
216			f.saveRef(x, context)
217		}
218	case *ast.CallExpr:
219		f.saveCall(x, context)
220	}
221}
222
223// Save references to C.xxx for later processing.
224func (f *File) saveRef(n *ast.Expr, context astContext) {
225	sel := (*n).(*ast.SelectorExpr)
226	// For now, assume that the only instance of capital C is when
227	// used as the imported package identifier.
228	// The parser should take care of scoping in the future, so
229	// that we will be able to distinguish a "top-level C" from a
230	// local C.
231	if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
232		return
233	}
234	if context == ctxAssign2 {
235		context = ctxExpr
236	}
237	if context == ctxEmbedType {
238		error_(sel.Pos(), "cannot embed C type")
239	}
240	goname := sel.Sel.Name
241	if goname == "errno" {
242		error_(sel.Pos(), "cannot refer to errno directly; see documentation")
243		return
244	}
245	if goname == "_CMalloc" {
246		error_(sel.Pos(), "cannot refer to C._CMalloc; use C.malloc")
247		return
248	}
249	if goname == "malloc" {
250		goname = "_CMalloc"
251	}
252	name := f.Name[goname]
253	if name == nil {
254		name = &Name{
255			Go: goname,
256		}
257		f.Name[goname] = name
258		f.NamePos[name] = sel.Pos()
259	}
260	f.Ref = append(f.Ref, &Ref{
261		Name:    name,
262		Expr:    n,
263		Context: context,
264	})
265}
266
267// Save calls to C.xxx for later processing.
268func (f *File) saveCall(call *ast.CallExpr, context astContext) {
269	sel, ok := call.Fun.(*ast.SelectorExpr)
270	if !ok {
271		return
272	}
273	if l, ok := sel.X.(*ast.Ident); !ok || l.Name != "C" {
274		return
275	}
276	c := &Call{Call: call, Deferred: context == ctxDefer}
277	f.Calls = append(f.Calls, c)
278}
279
280// If a function should be exported add it to ExpFunc.
281func (f *File) saveExport(x interface{}, context astContext) {
282	n, ok := x.(*ast.FuncDecl)
283	if !ok {
284		return
285	}
286
287	if n.Doc == nil {
288		return
289	}
290	for _, c := range n.Doc.List {
291		if !strings.HasPrefix(c.Text, "//export ") {
292			continue
293		}
294
295		name := strings.TrimSpace(c.Text[9:])
296		if name == "" {
297			error_(c.Pos(), "export missing name")
298		}
299
300		if name != n.Name.Name {
301			error_(c.Pos(), "export comment has wrong name %q, want %q", name, n.Name.Name)
302		}
303
304		doc := ""
305		for _, c1 := range n.Doc.List {
306			if c1 != c {
307				doc += c1.Text + "\n"
308			}
309		}
310
311		f.ExpFunc = append(f.ExpFunc, &ExpFunc{
312			Func:    n,
313			ExpName: name,
314			Doc:     doc,
315		})
316		break
317	}
318}
319
320// Make f.ExpFunc[i] point at the Func from this AST instead of the other one.
321func (f *File) saveExport2(x interface{}, context astContext) {
322	n, ok := x.(*ast.FuncDecl)
323	if !ok {
324		return
325	}
326
327	for _, exp := range f.ExpFunc {
328		if exp.Func.Name.Name == n.Name.Name {
329			exp.Func = n
330			break
331		}
332	}
333}
334
// astContext describes the syntactic position in which walk reached a node.
// ParseGo starts each walk at ctxProg, and walk refines the context as it
// descends. Visitors consult it: for example, saveRef rejects a C type in
// ctxEmbedType, and saveCall marks calls reached in ctxDefer as deferred.
type astContext int

// Where walk assigns each context:
//   - ctxEmbedType: the type of an unnamed field reached in ctxField
//     (an embedded struct field or an interface element).
//   - ctxType: types of named fields, composite literals, type assertions,
//     array/map/chan elements, function signatures and type specs, and the
//     case lists of a type switch.
//   - ctxStmt: statements and statement bodies.
//   - ctxExpr: ordinary expressions.
//   - ctxField: the field lists of struct and interface types.
//   - ctxParam: receivers, parameters, results and type parameters.
//   - ctxAssign2: the single right-hand side of a two-name assignment or
//     var spec.
//   - ctxSwitch / ctxTypeSwitch: the body of an expression / type switch.
//   - ctxFile: each file of an *ast.Package.
//   - ctxDecl: declarations; ctxSpec: the specs of a GenDecl.
//   - ctxDefer: the call of a defer statement.
//   - ctxCall / ctxCall2: the function part of a call (ctxCall2 when the
//     call was reached in ctxAssign2).
//   - ctxSelector: the X operand of a selector expression.
//
// StarExpr, ParenExpr and BlockStmt pass their own context through to
// their children.
const (
	ctxProg astContext = iota
	ctxEmbedType
	ctxType
	ctxStmt
	ctxExpr
	ctxField
	ctxParam
	ctxAssign2 // assignment of a single expression to two variables
	ctxSwitch
	ctxTypeSwitch
	ctxFile
	ctxDecl
	ctxSpec
	ctxDefer
	ctxCall  // any function call other than ctxCall2
	ctxCall2 // function call whose result is assigned to two variables
	ctxSelector
)
356
// walk walks the AST x, calling visit(f, x, context) for each node.
//
// visit is called on x before its children, including when x is nil. x may
// be:
//   - an AST node;
//   - a *ast.Expr, meaning the address of an expression slot, so a visitor
//     can keep the slot itself (saveRef stores it in Ref.Expr);
//   - one of the slices []ast.Decl, []ast.Expr, []ast.Stmt or []ast.Spec;
//   - nil.
//
// Expressions are generally walked through the address of the field that
// holds them. context gives the syntactic position of x and is refined for
// the children according to the astContext rules.
func (f *File) walk(x interface{}, context astContext, visit func(*File, interface{}, astContext)) {
	visit(f, x, context)
	switch n := x.(type) {
	case *ast.Expr:
		f.walk(*n, context, visit)

	// Node types not listed below are handed to walkUnexpected, which is
	// defined elsewhere in cgo. Every listed case recurses into the node's
	// children explicitly.
	default:
		f.walkUnexpected(x, context, visit)

	case nil:

	// These are ordered and grouped to match ../../go/ast/ast.go
	case *ast.Field:
		if len(n.Names) == 0 && context == ctxField {
			f.walk(&n.Type, ctxEmbedType, visit)
		} else {
			f.walk(&n.Type, ctxType, visit)
		}
	case *ast.FieldList:
		for _, field := range n.List {
			f.walk(field, context, visit)
		}
	case *ast.BadExpr:
	case *ast.Ident:
	case *ast.Ellipsis:
		f.walk(&n.Elt, ctxType, visit)
	case *ast.BasicLit:
	case *ast.FuncLit:
		f.walk(n.Type, ctxType, visit)
		f.walk(n.Body, ctxStmt, visit)
	case *ast.CompositeLit:
		f.walk(&n.Type, ctxType, visit)
		f.walk(n.Elts, ctxExpr, visit)
	case *ast.ParenExpr:
		f.walk(&n.X, context, visit)
	case *ast.SelectorExpr:
		f.walk(&n.X, ctxSelector, visit)
	case *ast.IndexExpr:
		f.walk(&n.X, ctxExpr, visit)
		f.walk(&n.Index, ctxExpr, visit)
	case *ast.SliceExpr:
		f.walk(&n.X, ctxExpr, visit)
		if n.Low != nil {
			f.walk(&n.Low, ctxExpr, visit)
		}
		if n.High != nil {
			f.walk(&n.High, ctxExpr, visit)
		}
		if n.Max != nil {
			f.walk(&n.Max, ctxExpr, visit)
		}
	case *ast.TypeAssertExpr:
		f.walk(&n.X, ctxExpr, visit)
		f.walk(&n.Type, ctxType, visit)
	case *ast.CallExpr:
		// A call that is the single right-hand side of a two-value
		// assignment has its function walked in ctxCall2.
		if context == ctxAssign2 {
			f.walk(&n.Fun, ctxCall2, visit)
		} else {
			f.walk(&n.Fun, ctxCall, visit)
		}
		f.walk(n.Args, ctxExpr, visit)
	case *ast.StarExpr:
		f.walk(&n.X, context, visit)
	case *ast.UnaryExpr:
		f.walk(&n.X, ctxExpr, visit)
	case *ast.BinaryExpr:
		f.walk(&n.X, ctxExpr, visit)
		f.walk(&n.Y, ctxExpr, visit)
	case *ast.KeyValueExpr:
		f.walk(&n.Key, ctxExpr, visit)
		f.walk(&n.Value, ctxExpr, visit)

	case *ast.ArrayType:
		f.walk(&n.Len, ctxExpr, visit)
		f.walk(&n.Elt, ctxType, visit)
	case *ast.StructType:
		f.walk(n.Fields, ctxField, visit)
	case *ast.FuncType:
		if tparams := funcTypeTypeParams(n); tparams != nil {
			f.walk(tparams, ctxParam, visit)
		}
		f.walk(n.Params, ctxParam, visit)
		if n.Results != nil {
			f.walk(n.Results, ctxParam, visit)
		}
	case *ast.InterfaceType:
		f.walk(n.Methods, ctxField, visit)
	case *ast.MapType:
		f.walk(&n.Key, ctxType, visit)
		f.walk(&n.Value, ctxType, visit)
	case *ast.ChanType:
		f.walk(&n.Value, ctxType, visit)

	case *ast.BadStmt:
	case *ast.DeclStmt:
		f.walk(n.Decl, ctxDecl, visit)
	case *ast.EmptyStmt:
	case *ast.LabeledStmt:
		f.walk(n.Stmt, ctxStmt, visit)
	case *ast.ExprStmt:
		f.walk(&n.X, ctxExpr, visit)
	case *ast.SendStmt:
		f.walk(&n.Chan, ctxExpr, visit)
		f.walk(&n.Value, ctxExpr, visit)
	case *ast.IncDecStmt:
		f.walk(&n.X, ctxExpr, visit)
	case *ast.AssignStmt:
		f.walk(n.Lhs, ctxExpr, visit)
		if len(n.Lhs) == 2 && len(n.Rhs) == 1 {
			f.walk(n.Rhs, ctxAssign2, visit)
		} else {
			f.walk(n.Rhs, ctxExpr, visit)
		}
	case *ast.GoStmt:
		f.walk(n.Call, ctxExpr, visit)
	case *ast.DeferStmt:
		f.walk(n.Call, ctxDefer, visit)
	case *ast.ReturnStmt:
		f.walk(n.Results, ctxExpr, visit)
	case *ast.BranchStmt:
	case *ast.BlockStmt:
		// Pass the context through so CaseClauses can tell whether they
		// belong to a type switch.
		f.walk(n.List, context, visit)
	case *ast.IfStmt:
		f.walk(n.Init, ctxStmt, visit)
		f.walk(&n.Cond, ctxExpr, visit)
		f.walk(n.Body, ctxStmt, visit)
		f.walk(n.Else, ctxStmt, visit)
	case *ast.CaseClause:
		if context == ctxTypeSwitch {
			context = ctxType
		} else {
			context = ctxExpr
		}
		f.walk(n.List, context, visit)
		f.walk(n.Body, ctxStmt, visit)
	case *ast.SwitchStmt:
		f.walk(n.Init, ctxStmt, visit)
		f.walk(&n.Tag, ctxExpr, visit)
		f.walk(n.Body, ctxSwitch, visit)
	case *ast.TypeSwitchStmt:
		f.walk(n.Init, ctxStmt, visit)
		f.walk(n.Assign, ctxStmt, visit)
		f.walk(n.Body, ctxTypeSwitch, visit)
	case *ast.CommClause:
		f.walk(n.Comm, ctxStmt, visit)
		f.walk(n.Body, ctxStmt, visit)
	case *ast.SelectStmt:
		f.walk(n.Body, ctxStmt, visit)
	case *ast.ForStmt:
		f.walk(n.Init, ctxStmt, visit)
		f.walk(&n.Cond, ctxExpr, visit)
		f.walk(n.Post, ctxStmt, visit)
		f.walk(n.Body, ctxStmt, visit)
	case *ast.RangeStmt:
		f.walk(&n.Key, ctxExpr, visit)
		f.walk(&n.Value, ctxExpr, visit)
		f.walk(&n.X, ctxExpr, visit)
		f.walk(n.Body, ctxStmt, visit)

	case *ast.ImportSpec:
	case *ast.ValueSpec:
		f.walk(&n.Type, ctxType, visit)
		if len(n.Names) == 2 && len(n.Values) == 1 {
			f.walk(&n.Values[0], ctxAssign2, visit)
		} else {
			f.walk(n.Values, ctxExpr, visit)
		}
	case *ast.TypeSpec:
		if tparams := typeSpecTypeParams(n); tparams != nil {
			f.walk(tparams, ctxParam, visit)
		}
		f.walk(&n.Type, ctxType, visit)

	case *ast.BadDecl:
	case *ast.GenDecl:
		f.walk(n.Specs, ctxSpec, visit)
	case *ast.FuncDecl:
		if n.Recv != nil {
			f.walk(n.Recv, ctxParam, visit)
		}
		f.walk(n.Type, ctxType, visit)
		// Body is nil for declarations without a body; a typed nil
		// *ast.BlockStmt would be dereferenced by the BlockStmt case.
		if n.Body != nil {
			f.walk(n.Body, ctxStmt, visit)
		}

	case *ast.File:
		f.walk(n.Decls, ctxDecl, visit)

	case *ast.Package:
		for _, file := range n.Files {
			f.walk(file, ctxFile, visit)
		}

	case []ast.Decl:
		for _, d := range n {
			f.walk(d, context, visit)
		}
	case []ast.Expr:
		// Walk element addresses so visitors can replace them in place.
		for i := range n {
			f.walk(&n[i], context, visit)
		}
	case []ast.Stmt:
		for _, s := range n {
			f.walk(s, context, visit)
		}
	case []ast.Spec:
		for _, s := range n {
			f.walk(s, context, visit)
		}
	}
}
570
// unparen strips any number of enclosing parentheses from x: for ((T)) it
// returns T. Any other expression, including nil, is returned unchanged.
func unparen(x ast.Expr) ast.Expr {
	for {
		paren, isParen := x.(*ast.ParenExpr)
		if !isParen {
			return x
		}
		x = paren.X
	}
}
578