1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package astutil
6
7// This file defines utilities for working with source positions.
8
9import (
10	"fmt"
11	"go/ast"
12	"go/token"
13	"sort"
14)
15
16// PathEnclosingInterval returns the node that encloses the source
17// interval [start, end), and all its ancestors up to the AST root.
18//
19// The definition of "enclosing" used by this function considers
20// additional whitespace abutting a node to be enclosed by it.
21// In this example:
22//
23//	z := x + y // add them
24//	     <-A->
25//	    <----B----->
26//
27// the ast.BinaryExpr(+) node is considered to enclose interval B
28// even though its [Pos()..End()) is actually only interval A.
29// This behaviour makes user interfaces more tolerant of imperfect
30// input.
31//
32// This function treats tokens as nodes, though they are not included
33// in the result. e.g. PathEnclosingInterval("+") returns the
34// enclosing ast.BinaryExpr("x + y").
35//
36// If start==end, the 1-char interval following start is used instead.
37//
38// The 'exact' result is true if the interval contains only path[0]
39// and perhaps some adjacent whitespace.  It is false if the interval
40// overlaps multiple children of path[0], or if it contains only
41// interior whitespace of path[0].
42// In this example:
43//
44//	z := x + y // add them
45//	  <--C-->     <---E-->
46//	    ^
47//	    D
48//
49// intervals C, D and E are inexact.  C is contained by the
50// z-assignment statement, because it spans three of its children (:=,
51// x, +).  So too is the 1-char interval D, because it contains only
52// interior whitespace of the assignment.  E is considered interior
53// whitespace of the BlockStmt containing the assignment.
54//
55// The resulting path is never empty; it always contains at least the
56// 'root' *ast.File.  Ideally PathEnclosingInterval would reject
57// intervals that lie wholly or partially outside the range of the
58// file, but unfortunately ast.File records only the token.Pos of
59// the 'package' keyword, but not of the start of the file itself.
60func PathEnclosingInterval(root *ast.File, start, end token.Pos) (path []ast.Node, exact bool) {
61	// fmt.Printf("EnclosingInterval %d %d\n", start, end) // debugging
62
63	// Precondition: node.[Pos..End) and adjoining whitespace contain [start, end).
64	var visit func(node ast.Node) bool
65	visit = func(node ast.Node) bool {
66		path = append(path, node)
67
68		nodePos := node.Pos()
69		nodeEnd := node.End()
70
71		// fmt.Printf("visit(%T, %d, %d)\n", node, nodePos, nodeEnd) // debugging
72
73		// Intersect [start, end) with interval of node.
74		if start < nodePos {
75			start = nodePos
76		}
77		if end > nodeEnd {
78			end = nodeEnd
79		}
80
81		// Find sole child that contains [start, end).
82		children := childrenOf(node)
83		l := len(children)
84		for i, child := range children {
85			// [childPos, childEnd) is unaugmented interval of child.
86			childPos := child.Pos()
87			childEnd := child.End()
88
89			// [augPos, augEnd) is whitespace-augmented interval of child.
90			augPos := childPos
91			augEnd := childEnd
92			if i > 0 {
93				augPos = children[i-1].End() // start of preceding whitespace
94			}
95			if i < l-1 {
96				nextChildPos := children[i+1].Pos()
97				// Does [start, end) lie between child and next child?
98				if start >= augEnd && end <= nextChildPos {
99					return false // inexact match
100				}
101				augEnd = nextChildPos // end of following whitespace
102			}
103
104			// fmt.Printf("\tchild %d: [%d..%d)\tcontains interval [%d..%d)?\n",
105			// 	i, augPos, augEnd, start, end) // debugging
106
107			// Does augmented child strictly contain [start, end)?
108			if augPos <= start && end <= augEnd {
109				_, isToken := child.(tokenNode)
110				return isToken || visit(child)
111			}
112
113			// Does [start, end) overlap multiple children?
114			// i.e. left-augmented child contains start
115			// but LR-augmented child does not contain end.
116			if start < childEnd && end > augEnd {
117				break
118			}
119		}
120
121		// No single child contained [start, end),
122		// so node is the result.  Is it exact?
123
124		// (It's tempting to put this condition before the
125		// child loop, but it gives the wrong result in the
126		// case where a node (e.g. ExprStmt) and its sole
127		// child have equal intervals.)
128		if start == nodePos && end == nodeEnd {
129			return true // exact match
130		}
131
132		return false // inexact: overlaps multiple children
133	}
134
135	// Ensure [start,end) is nondecreasing.
136	if start > end {
137		start, end = end, start
138	}
139
140	if start < root.End() && end > root.Pos() {
141		if start == end {
142			end = start + 1 // empty interval => interval of size 1
143		}
144		exact = visit(root)
145
146		// Reverse the path:
147		for i, l := 0, len(path); i < l/2; i++ {
148			path[i], path[l-1-i] = path[l-1-i], path[i]
149		}
150	} else {
151		// Selection lies within whitespace preceding the
152		// first (or following the last) declaration in the file.
153		// The result nonetheless always includes the ast.File.
154		path = append(path, root)
155	}
156
157	return
158}
159
// tokenNode is a synthetic ast.Node that stands for a single bare token
// (such as a keyword, operator, bracket, or identifier name) which the
// real AST records only as a position. Values are created transiently
// by childrenOf for PathEnclosingInterval and never escape this package.
type tokenNode struct {
	pos token.Pos // position of the token's first byte
	end token.Pos // position just past the token's last byte
}

// tokenNode must satisfy ast.Node so it can be mixed with real children.
var _ ast.Node = tokenNode{}

// Pos returns the position of the first byte of the token.
func (t tokenNode) Pos() token.Pos { return t.pos }

// End returns the position immediately after the token.
func (t tokenNode) End() token.Pos { return t.end }

// tok returns a tokenNode spanning the size bytes that begin at pos.
func tok(pos token.Pos, size int) ast.Node {
	end := pos + token.Pos(size)
	return tokenNode{pos: pos, end: end}
}
179
180// childrenOf returns the direct non-nil children of ast.Node n.
181// It may include fake ast.Node implementations for bare tokens.
182// it is not safe to call (e.g.) ast.Walk on such nodes.
183func childrenOf(n ast.Node) []ast.Node {
184	var children []ast.Node
185
186	// First add nodes for all true subtrees.
187	ast.Inspect(n, func(node ast.Node) bool {
188		if node == n { // push n
189			return true // recur
190		}
191		if node != nil { // push child
192			children = append(children, node)
193		}
194		return false // no recursion
195	})
196
197	// Then add fake Nodes for bare tokens.
198	switch n := n.(type) {
199	case *ast.ArrayType:
200		children = append(children,
201			tok(n.Lbrack, len("[")),
202			tok(n.Elt.End(), len("]")))
203
204	case *ast.AssignStmt:
205		children = append(children,
206			tok(n.TokPos, len(n.Tok.String())))
207
208	case *ast.BasicLit:
209		children = append(children,
210			tok(n.ValuePos, len(n.Value)))
211
212	case *ast.BinaryExpr:
213		children = append(children, tok(n.OpPos, len(n.Op.String())))
214
215	case *ast.BlockStmt:
216		children = append(children,
217			tok(n.Lbrace, len("{")),
218			tok(n.Rbrace, len("}")))
219
220	case *ast.BranchStmt:
221		children = append(children,
222			tok(n.TokPos, len(n.Tok.String())))
223
224	case *ast.CallExpr:
225		children = append(children,
226			tok(n.Lparen, len("(")),
227			tok(n.Rparen, len(")")))
228		if n.Ellipsis != 0 {
229			children = append(children, tok(n.Ellipsis, len("...")))
230		}
231
232	case *ast.CaseClause:
233		if n.List == nil {
234			children = append(children,
235				tok(n.Case, len("default")))
236		} else {
237			children = append(children,
238				tok(n.Case, len("case")))
239		}
240		children = append(children, tok(n.Colon, len(":")))
241
242	case *ast.ChanType:
243		switch n.Dir {
244		case ast.RECV:
245			children = append(children, tok(n.Begin, len("<-chan")))
246		case ast.SEND:
247			children = append(children, tok(n.Begin, len("chan<-")))
248		case ast.RECV | ast.SEND:
249			children = append(children, tok(n.Begin, len("chan")))
250		}
251
252	case *ast.CommClause:
253		if n.Comm == nil {
254			children = append(children,
255				tok(n.Case, len("default")))
256		} else {
257			children = append(children,
258				tok(n.Case, len("case")))
259		}
260		children = append(children, tok(n.Colon, len(":")))
261
262	case *ast.Comment:
263		// nop
264
265	case *ast.CommentGroup:
266		// nop
267
268	case *ast.CompositeLit:
269		children = append(children,
270			tok(n.Lbrace, len("{")),
271			tok(n.Rbrace, len("{")))
272
273	case *ast.DeclStmt:
274		// nop
275
276	case *ast.DeferStmt:
277		children = append(children,
278			tok(n.Defer, len("defer")))
279
280	case *ast.Ellipsis:
281		children = append(children,
282			tok(n.Ellipsis, len("...")))
283
284	case *ast.EmptyStmt:
285		// nop
286
287	case *ast.ExprStmt:
288		// nop
289
290	case *ast.Field:
291		// TODO(adonovan): Field.{Doc,Comment,Tag}?
292
293	case *ast.FieldList:
294		children = append(children,
295			tok(n.Opening, len("(")), // or len("[")
296			tok(n.Closing, len(")"))) // or len("]")
297
298	case *ast.File:
299		// TODO test: Doc
300		children = append(children,
301			tok(n.Package, len("package")))
302
303	case *ast.ForStmt:
304		children = append(children,
305			tok(n.For, len("for")))
306
307	case *ast.FuncDecl:
308		// TODO(adonovan): FuncDecl.Comment?
309
310		// Uniquely, FuncDecl breaks the invariant that
311		// preorder traversal yields tokens in lexical order:
312		// in fact, FuncDecl.Recv precedes FuncDecl.Type.Func.
313		//
314		// As a workaround, we inline the case for FuncType
315		// here and order things correctly.
316		//
317		children = nil // discard ast.Walk(FuncDecl) info subtrees
318		children = append(children, tok(n.Type.Func, len("func")))
319		if n.Recv != nil {
320			children = append(children, n.Recv)
321		}
322		children = append(children, n.Name)
323		if tparams := n.Type.TypeParams; tparams != nil {
324			children = append(children, tparams)
325		}
326		if n.Type.Params != nil {
327			children = append(children, n.Type.Params)
328		}
329		if n.Type.Results != nil {
330			children = append(children, n.Type.Results)
331		}
332		if n.Body != nil {
333			children = append(children, n.Body)
334		}
335
336	case *ast.FuncLit:
337		// nop
338
339	case *ast.FuncType:
340		if n.Func != 0 {
341			children = append(children,
342				tok(n.Func, len("func")))
343		}
344
345	case *ast.GenDecl:
346		children = append(children,
347			tok(n.TokPos, len(n.Tok.String())))
348		if n.Lparen != 0 {
349			children = append(children,
350				tok(n.Lparen, len("(")),
351				tok(n.Rparen, len(")")))
352		}
353
354	case *ast.GoStmt:
355		children = append(children,
356			tok(n.Go, len("go")))
357
358	case *ast.Ident:
359		children = append(children,
360			tok(n.NamePos, len(n.Name)))
361
362	case *ast.IfStmt:
363		children = append(children,
364			tok(n.If, len("if")))
365
366	case *ast.ImportSpec:
367		// TODO(adonovan): ImportSpec.{Doc,EndPos}?
368
369	case *ast.IncDecStmt:
370		children = append(children,
371			tok(n.TokPos, len(n.Tok.String())))
372
373	case *ast.IndexExpr:
374		children = append(children,
375			tok(n.Lbrack, len("[")),
376			tok(n.Rbrack, len("]")))
377
378	case *ast.IndexListExpr:
379		children = append(children,
380			tok(n.Lbrack, len("[")),
381			tok(n.Rbrack, len("]")))
382
383	case *ast.InterfaceType:
384		children = append(children,
385			tok(n.Interface, len("interface")))
386
387	case *ast.KeyValueExpr:
388		children = append(children,
389			tok(n.Colon, len(":")))
390
391	case *ast.LabeledStmt:
392		children = append(children,
393			tok(n.Colon, len(":")))
394
395	case *ast.MapType:
396		children = append(children,
397			tok(n.Map, len("map")))
398
399	case *ast.ParenExpr:
400		children = append(children,
401			tok(n.Lparen, len("(")),
402			tok(n.Rparen, len(")")))
403
404	case *ast.RangeStmt:
405		children = append(children,
406			tok(n.For, len("for")),
407			tok(n.TokPos, len(n.Tok.String())))
408
409	case *ast.ReturnStmt:
410		children = append(children,
411			tok(n.Return, len("return")))
412
413	case *ast.SelectStmt:
414		children = append(children,
415			tok(n.Select, len("select")))
416
417	case *ast.SelectorExpr:
418		// nop
419
420	case *ast.SendStmt:
421		children = append(children,
422			tok(n.Arrow, len("<-")))
423
424	case *ast.SliceExpr:
425		children = append(children,
426			tok(n.Lbrack, len("[")),
427			tok(n.Rbrack, len("]")))
428
429	case *ast.StarExpr:
430		children = append(children, tok(n.Star, len("*")))
431
432	case *ast.StructType:
433		children = append(children, tok(n.Struct, len("struct")))
434
435	case *ast.SwitchStmt:
436		children = append(children, tok(n.Switch, len("switch")))
437
438	case *ast.TypeAssertExpr:
439		children = append(children,
440			tok(n.Lparen-1, len(".")),
441			tok(n.Lparen, len("(")),
442			tok(n.Rparen, len(")")))
443
444	case *ast.TypeSpec:
445		// TODO(adonovan): TypeSpec.{Doc,Comment}?
446
447	case *ast.TypeSwitchStmt:
448		children = append(children, tok(n.Switch, len("switch")))
449
450	case *ast.UnaryExpr:
451		children = append(children, tok(n.OpPos, len(n.Op.String())))
452
453	case *ast.ValueSpec:
454		// TODO(adonovan): ValueSpec.{Doc,Comment}?
455
456	case *ast.BadDecl, *ast.BadExpr, *ast.BadStmt:
457		// nop
458	}
459
460	// TODO(adonovan): opt: merge the logic of ast.Inspect() into
461	// the switch above so we can make interleaved callbacks for
462	// both Nodes and Tokens in the right order and avoid the need
463	// to sort.
464	sort.Sort(byPos(children))
465
466	return children
467}
468
// byPos implements sort.Interface, ordering AST nodes by start position.
type byPos []ast.Node

func (s byPos) Len() int           { return len(s) }
func (s byPos) Less(i, j int) bool { return s[i].Pos() < s[j].Pos() }
func (s byPos) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
480
// NodeDescription returns a short, human-readable description of the
// concrete type of n, suitable for a user interface (for example
// "identifier" or "binary + operation"). The descriptions of
// *ast.DeclStmt and *ast.ParenExpr are derived recursively from the
// node they wrap.
//
// It panics if n is nil, is of a type not handled below, or is an
// *ast.BranchStmt or *ast.GenDecl carrying an unrecognized token.
//
// TODO(adonovan): in some cases (e.g. Field, FieldList, Ident,
// StarExpr) we could be much more specific given the path to the AST
// root.  Perhaps we should do that.
func NodeDescription(n ast.Node) string {
	switch x := n.(type) {
	// Declarations and specifications.
	case *ast.BadDecl:
		return "bad declaration"
	case *ast.FuncDecl:
		return "function declaration"
	case *ast.GenDecl:
		switch x.Tok {
		case token.IMPORT:
			return "import declaration"
		case token.CONST:
			return "constant declaration"
		case token.TYPE:
			return "type declaration"
		case token.VAR:
			return "variable declaration"
		}
	case *ast.ImportSpec:
		return "import specification"
	case *ast.TypeSpec:
		return "type specification"
	case *ast.ValueSpec:
		return "value specification"

	// Statements.
	case *ast.AssignStmt:
		return "assignment"
	case *ast.BadStmt:
		return "bad statement"
	case *ast.BlockStmt:
		return "block"
	case *ast.BranchStmt:
		switch x.Tok {
		case token.BREAK:
			return "break statement"
		case token.CONTINUE:
			return "continue statement"
		case token.GOTO:
			return "goto statement"
		case token.FALLTHROUGH:
			return "fall-through statement"
		}
	case *ast.CaseClause:
		return "case clause"
	case *ast.CommClause:
		return "communication clause"
	case *ast.DeclStmt:
		return NodeDescription(x.Decl) + " statement"
	case *ast.DeferStmt:
		return "defer statement"
	case *ast.EmptyStmt:
		return "empty statement"
	case *ast.ExprStmt:
		return "expression statement"
	case *ast.ForStmt:
		return "for loop"
	case *ast.GoStmt:
		return "go statement"
	case *ast.IfStmt:
		return "if statement"
	case *ast.IncDecStmt:
		if x.Tok != token.INC {
			return "decrement statement"
		}
		return "increment statement"
	case *ast.LabeledStmt:
		return "statement label"
	case *ast.RangeStmt:
		return "range loop"
	case *ast.ReturnStmt:
		return "return statement"
	case *ast.SelectStmt:
		return "select statement"
	case *ast.SendStmt:
		return "channel send"
	case *ast.SwitchStmt:
		return "switch statement"
	case *ast.TypeSwitchStmt:
		return "type switch"

	// Expressions.
	case *ast.BadExpr:
		return "bad expression"
	case *ast.BasicLit:
		return "basic literal"
	case *ast.BinaryExpr:
		return "binary " + x.Op.String() + " operation"
	case *ast.CallExpr:
		// A single argument without "..." could equally be a conversion.
		if len(x.Args) != 1 || x.Ellipsis.IsValid() {
			return "function call"
		}
		return "function call (or conversion)"
	case *ast.CompositeLit:
		return "composite literal"
	case *ast.Ellipsis:
		return "ellipsis"
	case *ast.FuncLit:
		return "function literal"
	case *ast.Ident:
		return "identifier"
	case *ast.IndexExpr:
		return "index expression"
	case *ast.IndexListExpr:
		return "index list expression"
	case *ast.KeyValueExpr:
		return "key/value association"
	case *ast.ParenExpr:
		return "parenthesized " + NodeDescription(x.X)
	case *ast.SelectorExpr:
		return "selector"
	case *ast.SliceExpr:
		return "slice expression"
	case *ast.StarExpr:
		return "*-operation" // load/store expr or pointer type
	case *ast.TypeAssertExpr:
		return "type assertion"
	case *ast.UnaryExpr:
		return "unary " + x.Op.String() + " operation"

	// Type expressions.
	case *ast.ArrayType:
		return "array type"
	case *ast.ChanType:
		return "channel type"
	case *ast.FuncType:
		return "function type"
	case *ast.InterfaceType:
		return "interface type"
	case *ast.MapType:
		return "map type"
	case *ast.StructType:
		return "struct type"

	// Fields. An *ast.Field can be any of these:
	//	struct {x, y int}  -- struct field(s)
	//	struct {T}         -- anon struct field
	//	interface {I}      -- interface embedding
	//	interface {f()}    -- interface method
	//	func (A) func(B) C -- receiver, param(s), result(s)
	case *ast.Field:
		return "field/method/parameter"
	case *ast.FieldList:
		return "field/method/parameter list"

	// Comments, files and packages.
	case *ast.Comment:
		return "comment"
	case *ast.CommentGroup:
		return "comment group"
	case *ast.File:
		return "source file"
	case *ast.Package:
		return "package"
	}
	panic(fmt.Sprintf("unexpected node type: %T", n))
}
635