xref: /aosp_15_r20/external/starlark-go/starlark/interp.go (revision 4947cdc739c985f6d86941e22894f5cefe7c9e9a)
1package starlark
2
3// This file defines the bytecode interpreter.
4
5import (
6	"fmt"
7	"os"
8	"sync/atomic"
9	"unsafe"
10
11	"go.starlark.net/internal/compile"
12	"go.starlark.net/internal/spell"
13	"go.starlark.net/resolve"
14	"go.starlark.net/syntax"
15)
16
17const vmdebug = false // TODO(adonovan): use a bitfield of specific kinds of error.
18
19// TODO(adonovan):
20// - optimize position table.
21// - opt: record MaxIterStack during compilation and preallocate the stack.
22
23func (fn *Function) CallInternal(thread *Thread, args Tuple, kwargs []Tuple) (Value, error) {
24	// Postcondition: args is not mutated. This is stricter than required by Callable,
25	// but allows CALL to avoid a copy.
26
27	if !resolve.AllowRecursion {
28		// detect recursion
29		for _, fr := range thread.stack[:len(thread.stack)-1] {
30			// We look for the same function code,
31			// not function value, otherwise the user could
32			// defeat the check by writing the Y combinator.
33			if frfn, ok := fr.Callable().(*Function); ok && frfn.funcode == fn.funcode {
34				return nil, fmt.Errorf("function %s called recursively", fn.Name())
35			}
36		}
37	}
38
39	f := fn.funcode
40	fr := thread.frameAt(0)
41
42	// Allocate space for stack and locals.
43	// Logically these do not escape from this frame
44	// (See https://github.com/golang/go/issues/20533.)
45	//
46	// This heap allocation looks expensive, but I was unable to get
47	// more than 1% real time improvement in a large alloc-heavy
48	// benchmark (in which this alloc was 8% of alloc-bytes)
49	// by allocating space for 8 Values in each frame, or
50	// by allocating stack by slicing an array held by the Thread
51	// that is expanded in chunks of min(k, nspace), for k=256 or 1024.
52	nlocals := len(f.Locals)
53	nspace := nlocals + f.MaxStack
54	space := make([]Value, nspace)
55	locals := space[:nlocals:nlocals] // local variables, starting with parameters
56	stack := space[nlocals:]          // operand stack
57
58	// Digest arguments and set parameters.
59	err := setArgs(locals, fn, args, kwargs)
60	if err != nil {
61		return nil, thread.evalError(err)
62	}
63
64	fr.locals = locals
65
66	if vmdebug {
67		fmt.Printf("Entering %s @ %s\n", f.Name, f.Position(0))
68		fmt.Printf("%d stack, %d locals\n", len(stack), len(locals))
69		defer fmt.Println("Leaving ", f.Name)
70	}
71
72	// Spill indicated locals to cells.
73	// Each cell is a separate alloc to avoid spurious liveness.
74	for _, index := range f.Cells {
75		locals[index] = &cell{locals[index]}
76	}
77
78	// TODO(adonovan): add static check that beneath this point
79	// - there is exactly one return statement
80	// - there is no redefinition of 'err'.
81
82	var iterstack []Iterator // stack of active iterators
83
84	sp := 0
85	var pc uint32
86	var result Value
87	code := f.Code
88loop:
89	for {
90		thread.steps++
91		if thread.steps >= thread.maxSteps {
92			thread.Cancel("too many steps")
93		}
94		if reason := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&thread.cancelReason))); reason != nil {
95			err = fmt.Errorf("Starlark computation cancelled: %s", *(*string)(reason))
96			break loop
97		}
98
99		fr.pc = pc
100
101		op := compile.Opcode(code[pc])
102		pc++
103		var arg uint32
104		if op >= compile.OpcodeArgMin {
105			// TODO(adonovan): opt: profile this.
106			// Perhaps compiling big endian would be less work to decode?
107			for s := uint(0); ; s += 7 {
108				b := code[pc]
109				pc++
110				arg |= uint32(b&0x7f) << s
111				if b < 0x80 {
112					break
113				}
114			}
115		}
116		if vmdebug {
117			fmt.Fprintln(os.Stderr, stack[:sp]) // very verbose!
118			compile.PrintOp(f, fr.pc, op, arg)
119		}
120
121		switch op {
122		case compile.NOP:
123			// nop
124
125		case compile.DUP:
126			stack[sp] = stack[sp-1]
127			sp++
128
129		case compile.DUP2:
130			stack[sp] = stack[sp-2]
131			stack[sp+1] = stack[sp-1]
132			sp += 2
133
134		case compile.POP:
135			sp--
136
137		case compile.EXCH:
138			stack[sp-2], stack[sp-1] = stack[sp-1], stack[sp-2]
139
140		case compile.EQL, compile.NEQ, compile.GT, compile.LT, compile.LE, compile.GE:
141			op := syntax.Token(op-compile.EQL) + syntax.EQL
142			y := stack[sp-1]
143			x := stack[sp-2]
144			sp -= 2
145			ok, err2 := Compare(op, x, y)
146			if err2 != nil {
147				err = err2
148				break loop
149			}
150			stack[sp] = Bool(ok)
151			sp++
152
153		case compile.PLUS,
154			compile.MINUS,
155			compile.STAR,
156			compile.SLASH,
157			compile.SLASHSLASH,
158			compile.PERCENT,
159			compile.AMP,
160			compile.PIPE,
161			compile.CIRCUMFLEX,
162			compile.LTLT,
163			compile.GTGT,
164			compile.IN:
165			binop := syntax.Token(op-compile.PLUS) + syntax.PLUS
166			if op == compile.IN {
167				binop = syntax.IN // IN token is out of order
168			}
169			y := stack[sp-1]
170			x := stack[sp-2]
171			sp -= 2
172			z, err2 := Binary(binop, x, y)
173			if err2 != nil {
174				err = err2
175				break loop
176			}
177			stack[sp] = z
178			sp++
179
180		case compile.UPLUS, compile.UMINUS, compile.TILDE:
181			var unop syntax.Token
182			if op == compile.TILDE {
183				unop = syntax.TILDE
184			} else {
185				unop = syntax.Token(op-compile.UPLUS) + syntax.PLUS
186			}
187			x := stack[sp-1]
188			y, err2 := Unary(unop, x)
189			if err2 != nil {
190				err = err2
191				break loop
192			}
193			stack[sp-1] = y
194
195		case compile.INPLACE_ADD:
196			y := stack[sp-1]
197			x := stack[sp-2]
198			sp -= 2
199
200			// It's possible that y is not Iterable but
201			// nonetheless defines x+y, in which case we
202			// should fall back to the general case.
203			var z Value
204			if xlist, ok := x.(*List); ok {
205				if yiter, ok := y.(Iterable); ok {
206					if err = xlist.checkMutable("apply += to"); err != nil {
207						break loop
208					}
209					listExtend(xlist, yiter)
210					z = xlist
211				}
212			}
213			if z == nil {
214				z, err = Binary(syntax.PLUS, x, y)
215				if err != nil {
216					break loop
217				}
218			}
219
220			stack[sp] = z
221			sp++
222
223		case compile.NONE:
224			stack[sp] = None
225			sp++
226
227		case compile.TRUE:
228			stack[sp] = True
229			sp++
230
231		case compile.FALSE:
232			stack[sp] = False
233			sp++
234
235		case compile.MANDATORY:
236			stack[sp] = mandatory{}
237			sp++
238
239		case compile.JMP:
240			pc = arg
241
242		case compile.CALL, compile.CALL_VAR, compile.CALL_KW, compile.CALL_VAR_KW:
243			var kwargs Value
244			if op == compile.CALL_KW || op == compile.CALL_VAR_KW {
245				kwargs = stack[sp-1]
246				sp--
247			}
248
249			var args Value
250			if op == compile.CALL_VAR || op == compile.CALL_VAR_KW {
251				args = stack[sp-1]
252				sp--
253			}
254
255			// named args (pairs)
256			var kvpairs []Tuple
257			if nkvpairs := int(arg & 0xff); nkvpairs > 0 {
258				kvpairs = make([]Tuple, 0, nkvpairs)
259				kvpairsAlloc := make(Tuple, 2*nkvpairs) // allocate a single backing array
260				sp -= 2 * nkvpairs
261				for i := 0; i < nkvpairs; i++ {
262					pair := kvpairsAlloc[:2:2]
263					kvpairsAlloc = kvpairsAlloc[2:]
264					pair[0] = stack[sp+2*i]   // name
265					pair[1] = stack[sp+2*i+1] // value
266					kvpairs = append(kvpairs, pair)
267				}
268			}
269			if kwargs != nil {
270				// Add key/value items from **kwargs dictionary.
271				dict, ok := kwargs.(IterableMapping)
272				if !ok {
273					err = fmt.Errorf("argument after ** must be a mapping, not %s", kwargs.Type())
274					break loop
275				}
276				items := dict.Items()
277				for _, item := range items {
278					if _, ok := item[0].(String); !ok {
279						err = fmt.Errorf("keywords must be strings, not %s", item[0].Type())
280						break loop
281					}
282				}
283				if len(kvpairs) == 0 {
284					kvpairs = items
285				} else {
286					kvpairs = append(kvpairs, items...)
287				}
288			}
289
290			// positional args
291			var positional Tuple
292			if npos := int(arg >> 8); npos > 0 {
293				positional = stack[sp-npos : sp]
294				sp -= npos
295
296				// Copy positional arguments into a new array,
297				// unless the callee is another Starlark function,
298				// in which case it can be trusted not to mutate them.
299				if _, ok := stack[sp-1].(*Function); !ok || args != nil {
300					positional = append(Tuple(nil), positional...)
301				}
302			}
303			if args != nil {
304				// Add elements from *args sequence.
305				iter := Iterate(args)
306				if iter == nil {
307					err = fmt.Errorf("argument after * must be iterable, not %s", args.Type())
308					break loop
309				}
310				var elem Value
311				for iter.Next(&elem) {
312					positional = append(positional, elem)
313				}
314				iter.Done()
315			}
316
317			function := stack[sp-1]
318
319			if vmdebug {
320				fmt.Printf("VM call %s args=%s kwargs=%s @%s\n",
321					function, positional, kvpairs, f.Position(fr.pc))
322			}
323
324			thread.endProfSpan()
325			z, err2 := Call(thread, function, positional, kvpairs)
326			thread.beginProfSpan()
327			if err2 != nil {
328				err = err2
329				break loop
330			}
331			if vmdebug {
332				fmt.Printf("Resuming %s @ %s\n", f.Name, f.Position(0))
333			}
334			stack[sp-1] = z
335
336		case compile.ITERPUSH:
337			x := stack[sp-1]
338			sp--
339			iter := Iterate(x)
340			if iter == nil {
341				err = fmt.Errorf("%s value is not iterable", x.Type())
342				break loop
343			}
344			iterstack = append(iterstack, iter)
345
346		case compile.ITERJMP:
347			iter := iterstack[len(iterstack)-1]
348			if iter.Next(&stack[sp]) {
349				sp++
350			} else {
351				pc = arg
352			}
353
354		case compile.ITERPOP:
355			n := len(iterstack) - 1
356			iterstack[n].Done()
357			iterstack = iterstack[:n]
358
359		case compile.NOT:
360			stack[sp-1] = !stack[sp-1].Truth()
361
362		case compile.RETURN:
363			result = stack[sp-1]
364			break loop
365
366		case compile.SETINDEX:
367			z := stack[sp-1]
368			y := stack[sp-2]
369			x := stack[sp-3]
370			sp -= 3
371			err = setIndex(x, y, z)
372			if err != nil {
373				break loop
374			}
375
376		case compile.INDEX:
377			y := stack[sp-1]
378			x := stack[sp-2]
379			sp -= 2
380			z, err2 := getIndex(x, y)
381			if err2 != nil {
382				err = err2
383				break loop
384			}
385			stack[sp] = z
386			sp++
387
388		case compile.ATTR:
389			x := stack[sp-1]
390			name := f.Prog.Names[arg]
391			y, err2 := getAttr(x, name)
392			if err2 != nil {
393				err = err2
394				break loop
395			}
396			stack[sp-1] = y
397
398		case compile.SETFIELD:
399			y := stack[sp-1]
400			x := stack[sp-2]
401			sp -= 2
402			name := f.Prog.Names[arg]
403			if err2 := setField(x, name, y); err2 != nil {
404				err = err2
405				break loop
406			}
407
408		case compile.MAKEDICT:
409			stack[sp] = new(Dict)
410			sp++
411
412		case compile.SETDICT, compile.SETDICTUNIQ:
413			dict := stack[sp-3].(*Dict)
414			k := stack[sp-2]
415			v := stack[sp-1]
416			sp -= 3
417			oldlen := dict.Len()
418			if err2 := dict.SetKey(k, v); err2 != nil {
419				err = err2
420				break loop
421			}
422			if op == compile.SETDICTUNIQ && dict.Len() == oldlen {
423				err = fmt.Errorf("duplicate key: %v", k)
424				break loop
425			}
426
427		case compile.APPEND:
428			elem := stack[sp-1]
429			list := stack[sp-2].(*List)
430			sp -= 2
431			list.elems = append(list.elems, elem)
432
433		case compile.SLICE:
434			x := stack[sp-4]
435			lo := stack[sp-3]
436			hi := stack[sp-2]
437			step := stack[sp-1]
438			sp -= 4
439			res, err2 := slice(x, lo, hi, step)
440			if err2 != nil {
441				err = err2
442				break loop
443			}
444			stack[sp] = res
445			sp++
446
447		case compile.UNPACK:
448			n := int(arg)
449			iterable := stack[sp-1]
450			sp--
451			iter := Iterate(iterable)
452			if iter == nil {
453				err = fmt.Errorf("got %s in sequence assignment", iterable.Type())
454				break loop
455			}
456			i := 0
457			sp += n
458			for i < n && iter.Next(&stack[sp-1-i]) {
459				i++
460			}
461			var dummy Value
462			if iter.Next(&dummy) {
463				// NB: Len may return -1 here in obscure cases.
464				err = fmt.Errorf("too many values to unpack (got %d, want %d)", Len(iterable), n)
465				break loop
466			}
467			iter.Done()
468			if i < n {
469				err = fmt.Errorf("too few values to unpack (got %d, want %d)", i, n)
470				break loop
471			}
472
473		case compile.CJMP:
474			if stack[sp-1].Truth() {
475				pc = arg
476			}
477			sp--
478
479		case compile.CONSTANT:
480			stack[sp] = fn.module.constants[arg]
481			sp++
482
483		case compile.MAKETUPLE:
484			n := int(arg)
485			tuple := make(Tuple, n)
486			sp -= n
487			copy(tuple, stack[sp:])
488			stack[sp] = tuple
489			sp++
490
491		case compile.MAKELIST:
492			n := int(arg)
493			elems := make([]Value, n)
494			sp -= n
495			copy(elems, stack[sp:])
496			stack[sp] = NewList(elems)
497			sp++
498
499		case compile.MAKEFUNC:
500			funcode := f.Prog.Functions[arg]
501			tuple := stack[sp-1].(Tuple)
502			n := len(tuple) - len(funcode.Freevars)
503			defaults := tuple[:n:n]
504			freevars := tuple[n:]
505			stack[sp-1] = &Function{
506				funcode:  funcode,
507				module:   fn.module,
508				defaults: defaults,
509				freevars: freevars,
510			}
511
512		case compile.LOAD:
513			n := int(arg)
514			module := string(stack[sp-1].(String))
515			sp--
516
517			if thread.Load == nil {
518				err = fmt.Errorf("load not implemented by this application")
519				break loop
520			}
521
522			thread.endProfSpan()
523			dict, err2 := thread.Load(thread, module)
524			thread.beginProfSpan()
525			if err2 != nil {
526				err = wrappedError{
527					msg:   fmt.Sprintf("cannot load %s: %v", module, err2),
528					cause: err2,
529				}
530				break loop
531			}
532
533			for i := 0; i < n; i++ {
534				from := string(stack[sp-1-i].(String))
535				v, ok := dict[from]
536				if !ok {
537					err = fmt.Errorf("load: name %s not found in module %s", from, module)
538					if n := spell.Nearest(from, dict.Keys()); n != "" {
539						err = fmt.Errorf("%s (did you mean %s?)", err, n)
540					}
541					break loop
542				}
543				stack[sp-1-i] = v
544			}
545
546		case compile.SETLOCAL:
547			locals[arg] = stack[sp-1]
548			sp--
549
550		case compile.SETLOCALCELL:
551			locals[arg].(*cell).v = stack[sp-1]
552			sp--
553
554		case compile.SETGLOBAL:
555			fn.module.globals[arg] = stack[sp-1]
556			sp--
557
558		case compile.LOCAL:
559			x := locals[arg]
560			if x == nil {
561				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
562				break loop
563			}
564			stack[sp] = x
565			sp++
566
567		case compile.FREE:
568			stack[sp] = fn.freevars[arg]
569			sp++
570
571		case compile.LOCALCELL:
572			v := locals[arg].(*cell).v
573			if v == nil {
574				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
575				break loop
576			}
577			stack[sp] = v
578			sp++
579
580		case compile.FREECELL:
581			v := fn.freevars[arg].(*cell).v
582			if v == nil {
583				err = fmt.Errorf("local variable %s referenced before assignment", f.Freevars[arg].Name)
584				break loop
585			}
586			stack[sp] = v
587			sp++
588
589		case compile.GLOBAL:
590			x := fn.module.globals[arg]
591			if x == nil {
592				err = fmt.Errorf("global variable %s referenced before assignment", f.Prog.Globals[arg].Name)
593				break loop
594			}
595			stack[sp] = x
596			sp++
597
598		case compile.PREDECLARED:
599			name := f.Prog.Names[arg]
600			x := fn.module.predeclared[name]
601			if x == nil {
602				err = fmt.Errorf("internal error: predeclared variable %s is uninitialized", name)
603				break loop
604			}
605			stack[sp] = x
606			sp++
607
608		case compile.UNIVERSAL:
609			stack[sp] = Universe[f.Prog.Names[arg]]
610			sp++
611
612		default:
613			err = fmt.Errorf("unimplemented: %s", op)
614			break loop
615		}
616	}
617
618	// ITERPOP the rest of the iterator stack.
619	for _, iter := range iterstack {
620		iter.Done()
621	}
622
623	fr.locals = nil
624
625	return result, err
626}
627
628type wrappedError struct {
629	msg   string
630	cause error
631}
632
633func (e wrappedError) Error() string {
634	return e.msg
635}
636
637// Implements the xerrors.Wrapper interface
638// https://godoc.org/golang.org/x/xerrors#Wrapper
639func (e wrappedError) Unwrap() error {
640	return e.cause
641}
642
643// mandatory is a sentinel value used in a function's defaults tuple
644// to indicate that a (keyword-only) parameter is mandatory.
645type mandatory struct{}
646
647func (mandatory) String() string        { return "mandatory" }
648func (mandatory) Type() string          { return "mandatory" }
649func (mandatory) Freeze()               {} // immutable
650func (mandatory) Truth() Bool           { return False }
651func (mandatory) Hash() (uint32, error) { return 0, nil }
652
653// A cell is a box containing a Value.
654// Local variables marked as cells hold their value indirectly
655// so that they may be shared by outer and inner nested functions.
656// Cells are always accessed using indirect {FREE,LOCAL,SETLOCAL}CELL instructions.
657// The FreeVars tuple contains only cells.
658// The FREE instruction always yields a cell.
659type cell struct{ v Value }
660
661func (c *cell) String() string { return "cell" }
662func (c *cell) Type() string   { return "cell" }
663func (c *cell) Freeze() {
664	if c.v != nil {
665		c.v.Freeze()
666	}
667}
668func (c *cell) Truth() Bool           { panic("unreachable") }
669func (c *cell) Hash() (uint32, error) { panic("unreachable") }
670