1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build ignore
6
7// Note: this program must be run in this directory.
8//   go run mknode.go
9
10package main
11
12import (
13	"bytes"
14	"fmt"
15	"go/ast"
16	"go/format"
17	"go/parser"
18	"go/token"
19	"io/fs"
20	"log"
21	"os"
22	"sort"
23	"strings"
24)
25
var (
	// fset holds position information for the files parsed by main.
	fset = token.NewFileSet()

	// buf accumulates the generated source; main formats it and writes
	// it to node_gen.go.
	buf bytes.Buffer

	// concreteNodes lists every concrete (struct) type in the package
	// that implements Node, excluding the mini* types themselves.
	concreteNodes []*ast.TypeSpec

	// interfaceNodes lists every interface type in the package that
	// implements Node.
	interfaceNodes []*ast.TypeSpec

	// mini maps the name of each embeddable mini type (miniNode,
	// miniExpr, miniStmt, ...) to its declaration.
	mini = make(map[string]*ast.TypeSpec)
)
39
40// implementsNode reports whether the type t is one which represents a Node
41// in the AST.
42func implementsNode(t ast.Expr) bool {
43	id, ok := t.(*ast.Ident)
44	if !ok {
45		return false // only named types
46	}
47	for _, ts := range interfaceNodes {
48		if ts.Name.Name == id.Name {
49			return true
50		}
51	}
52	for _, ts := range concreteNodes {
53		if ts.Name.Name == id.Name {
54			return true
55		}
56	}
57	return false
58}
59
60func isMini(t ast.Expr) bool {
61	id, ok := t.(*ast.Ident)
62	return ok && mini[id.Name] != nil
63}
64
// isNamedType reports whether t is a bare identifier spelled exactly name.
// Qualified, pointer, slice and other composite type expressions never match.
func isNamedType(t ast.Expr, name string) bool {
	ident, isIdent := t.(*ast.Ident)
	return isIdent && ident.Name == name
}
73
74func main() {
75	fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.")
76	fmt.Fprintln(&buf)
77	fmt.Fprintln(&buf, "package ir")
78	fmt.Fprintln(&buf)
79	fmt.Fprintln(&buf, `import "fmt"`)
80
81	filter := func(file fs.FileInfo) bool {
82		return !strings.HasPrefix(file.Name(), "mknode")
83	}
84	pkgs, err := parser.ParseDir(fset, ".", filter, 0)
85	if err != nil {
86		panic(err)
87	}
88	pkg := pkgs["ir"]
89
90	// Find all the mini types. These let us determine which
91	// concrete types implement Node, so we need to find them first.
92	for _, f := range pkg.Files {
93		for _, d := range f.Decls {
94			g, ok := d.(*ast.GenDecl)
95			if !ok {
96				continue
97			}
98			for _, s := range g.Specs {
99				t, ok := s.(*ast.TypeSpec)
100				if !ok {
101					continue
102				}
103				if strings.HasPrefix(t.Name.Name, "mini") {
104					mini[t.Name.Name] = t
105					// Double-check that it is or embeds miniNode.
106					if t.Name.Name != "miniNode" {
107						s := t.Type.(*ast.StructType)
108						if !isNamedType(s.Fields.List[0].Type, "miniNode") {
109							panic(fmt.Sprintf("can't find miniNode in %s", t.Name.Name))
110						}
111					}
112				}
113			}
114		}
115	}
116
117	// Find all the declarations of concrete types that implement Node.
118	for _, f := range pkg.Files {
119		for _, d := range f.Decls {
120			g, ok := d.(*ast.GenDecl)
121			if !ok {
122				continue
123			}
124			for _, s := range g.Specs {
125				t, ok := s.(*ast.TypeSpec)
126				if !ok {
127					continue
128				}
129				if strings.HasPrefix(t.Name.Name, "mini") {
130					// We don't treat the mini types as
131					// concrete implementations of Node
132					// (even though they are) because
133					// we only use them by embedding them.
134					continue
135				}
136				if isConcreteNode(t) {
137					concreteNodes = append(concreteNodes, t)
138				}
139				if isInterfaceNode(t) {
140					interfaceNodes = append(interfaceNodes, t)
141				}
142			}
143		}
144	}
145	// Sort for deterministic output.
146	sort.Slice(concreteNodes, func(i, j int) bool {
147		return concreteNodes[i].Name.Name < concreteNodes[j].Name.Name
148	})
149	// Generate code for each concrete type.
150	for _, t := range concreteNodes {
151		processType(t)
152	}
153	// Add some helpers.
154	generateHelpers()
155
156	// Format and write output.
157	out, err := format.Source(buf.Bytes())
158	if err != nil {
159		// write out mangled source so we can see the bug.
160		out = buf.Bytes()
161	}
162	err = os.WriteFile("node_gen.go", out, 0666)
163	if err != nil {
164		log.Fatal(err)
165	}
166}
167
168// isConcreteNode reports whether the type t is a concrete type
169// implementing Node.
170func isConcreteNode(t *ast.TypeSpec) bool {
171	s, ok := t.Type.(*ast.StructType)
172	if !ok {
173		return false
174	}
175	for _, f := range s.Fields.List {
176		if isMini(f.Type) {
177			return true
178		}
179	}
180	return false
181}
182
183// isInterfaceNode reports whether the type t is an interface type
184// implementing Node (including Node itself).
185func isInterfaceNode(t *ast.TypeSpec) bool {
186	s, ok := t.Type.(*ast.InterfaceType)
187	if !ok {
188		return false
189	}
190	if t.Name.Name == "Node" {
191		return true
192	}
193	if t.Name.Name == "OrigNode" || t.Name.Name == "InitNode" {
194		// These we exempt from consideration (fields of
195		// this type don't need to be walked or copied).
196		return false
197	}
198
199	// Look for embedded Node type.
200	// Note that this doesn't handle multi-level embedding, but
201	// we have none of that at the moment.
202	for _, f := range s.Methods.List {
203		if len(f.Names) != 0 {
204			continue
205		}
206		if isNamedType(f.Type, "Node") {
207			return true
208		}
209	}
210	return false
211}
212
213func processType(t *ast.TypeSpec) {
214	name := t.Name.Name
215	fmt.Fprintf(&buf, "\n")
216	fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)
217
218	switch name {
219	case "Name", "Func":
220		// Too specialized to automate.
221		return
222	}
223
224	s := t.Type.(*ast.StructType)
225	fields := s.Fields.List
226
227	// Expand any embedded fields.
228	for i := 0; i < len(fields); i++ {
229		f := fields[i]
230		if len(f.Names) != 0 {
231			continue // not embedded
232		}
233		if isMini(f.Type) {
234			// Insert the fields of the embedded type into the main type.
235			// (It would be easier just to append, but inserting in place
236			// matches the old mknode behavior.)
237			ss := mini[f.Type.(*ast.Ident).Name].Type.(*ast.StructType)
238			var f2 []*ast.Field
239			f2 = append(f2, fields[:i]...)
240			f2 = append(f2, ss.Fields.List...)
241			f2 = append(f2, fields[i+1:]...)
242			fields = f2
243			i--
244			continue
245		} else if isNamedType(f.Type, "origNode") {
246			// Ignore this field
247			copy(fields[i:], fields[i+1:])
248			fields = fields[:len(fields)-1]
249			i--
250			continue
251		} else {
252			panic("unknown embedded field " + fmt.Sprintf("%v", f.Type))
253		}
254	}
255	// Process fields.
256	var copyBody strings.Builder
257	var doChildrenBody strings.Builder
258	var editChildrenBody strings.Builder
259	var editChildrenWithHiddenBody strings.Builder
260	for _, f := range fields {
261		names := f.Names
262		ft := f.Type
263		hidden := false
264		if f.Tag != nil {
265			tag := f.Tag.Value[1 : len(f.Tag.Value)-1]
266			if strings.HasPrefix(tag, "mknode:") {
267				if tag[7:] == "\"-\"" {
268					if !isNamedType(ft, "Node") {
269						continue
270					}
271					hidden = true
272				} else {
273					panic(fmt.Sprintf("unexpected tag value: %s", tag))
274				}
275			}
276		}
277		if isNamedType(ft, "Nodes") {
278			// Nodes == []Node
279			ft = &ast.ArrayType{Elt: &ast.Ident{Name: "Node"}}
280		}
281		isSlice := false
282		if a, ok := ft.(*ast.ArrayType); ok && a.Len == nil {
283			isSlice = true
284			ft = a.Elt
285		}
286		isPtr := false
287		if p, ok := ft.(*ast.StarExpr); ok {
288			isPtr = true
289			ft = p.X
290		}
291		if !implementsNode(ft) {
292			continue
293		}
294		for _, name := range names {
295			ptr := ""
296			if isPtr {
297				ptr = "*"
298			}
299			if isSlice {
300				fmt.Fprintf(&editChildrenWithHiddenBody,
301					"edit%ss(n.%s, edit)\n", ft, name)
302			} else {
303				fmt.Fprintf(&editChildrenWithHiddenBody,
304					"if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
305			}
306			if hidden {
307				continue
308			}
309			if isSlice {
310				fmt.Fprintf(&copyBody, "c.%s = copy%ss(c.%s)\n", name, ft, name)
311				fmt.Fprintf(&doChildrenBody,
312					"if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
313				fmt.Fprintf(&editChildrenBody,
314					"edit%ss(n.%s, edit)\n", ft, name)
315			} else {
316				fmt.Fprintf(&doChildrenBody,
317					"if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
318				fmt.Fprintf(&editChildrenBody,
319					"if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
320			}
321		}
322	}
323	fmt.Fprintf(&buf, "func (n *%s) copy() Node {\nc := *n\n", name)
324	buf.WriteString(copyBody.String())
325	fmt.Fprintf(&buf, "return &c\n}\n")
326	fmt.Fprintf(&buf, "func (n *%s) doChildren(do func(Node) bool) bool {\n", name)
327	buf.WriteString(doChildrenBody.String())
328	fmt.Fprintf(&buf, "return false\n}\n")
329	fmt.Fprintf(&buf, "func (n *%s) editChildren(edit func(Node) Node) {\n", name)
330	buf.WriteString(editChildrenBody.String())
331	fmt.Fprintf(&buf, "}\n")
332	fmt.Fprintf(&buf, "func (n *%s) editChildrenWithHidden(edit func(Node) Node) {\n", name)
333	buf.WriteString(editChildrenWithHiddenBody.String())
334	fmt.Fprintf(&buf, "}\n")
335}
336
337func generateHelpers() {
338	for _, typ := range []string{"CaseClause", "CommClause", "Name", "Node"} {
339		ptr := "*"
340		if typ == "Node" {
341			ptr = "" // interfaces don't need *
342		}
343		fmt.Fprintf(&buf, "\n")
344		fmt.Fprintf(&buf, "func copy%ss(list []%s%s) []%s%s {\n", typ, ptr, typ, ptr, typ)
345		fmt.Fprintf(&buf, "if list == nil { return nil }\n")
346		fmt.Fprintf(&buf, "c := make([]%s%s, len(list))\n", ptr, typ)
347		fmt.Fprintf(&buf, "copy(c, list)\n")
348		fmt.Fprintf(&buf, "return c\n")
349		fmt.Fprintf(&buf, "}\n")
350		fmt.Fprintf(&buf, "func do%ss(list []%s%s, do func(Node) bool) bool {\n", typ, ptr, typ)
351		fmt.Fprintf(&buf, "for _, x := range list {\n")
352		fmt.Fprintf(&buf, "if x != nil && do(x) {\n")
353		fmt.Fprintf(&buf, "return true\n")
354		fmt.Fprintf(&buf, "}\n")
355		fmt.Fprintf(&buf, "}\n")
356		fmt.Fprintf(&buf, "return false\n")
357		fmt.Fprintf(&buf, "}\n")
358		fmt.Fprintf(&buf, "func edit%ss(list []%s%s, edit func(Node) Node) {\n", typ, ptr, typ)
359		fmt.Fprintf(&buf, "for i, x := range list {\n")
360		fmt.Fprintf(&buf, "if x != nil {\n")
361		fmt.Fprintf(&buf, "list[i] = edit(x).(%s%s)\n", ptr, typ)
362		fmt.Fprintf(&buf, "}\n")
363		fmt.Fprintf(&buf, "}\n")
364		fmt.Fprintf(&buf, "}\n")
365	}
366}
367