1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package constant implements Values representing untyped
6// Go constants and their corresponding operations.
7//
8// A special Unknown value may be used when a value
9// is unknown due to an error. Operations on unknown
10// values produce unknown values unless specified
11// otherwise.
12package constant
13
14import (
15	"fmt"
16	"go/token"
17	"math"
18	"math/big"
19	"math/bits"
20	"strconv"
21	"strings"
22	"sync"
23	"unicode/utf8"
24)
25
26//go:generate stringer -type Kind
27
// Kind specifies the kind of value represented by a [Value].
type Kind int

// The Kind constants, in increasing order. Kind's String method is
// generated by the `stringer -type Kind` go:generate directive in this
// file, so it must be regenerated after adding, removing, or reordering
// these constants.
const (
	// unknown values
	Unknown Kind = iota

	// non-numeric values
	Bool
	String

	// numeric values
	Int
	Float
	Complex
)
44
// A Value represents the value of a Go constant.
//
// Value is implemented only by this package's unexported representation
// types; construct values with the Make* functions and extract them with
// the accessor functions (BoolVal, StringVal, Int64Val, ...).
type Value interface {
	// Kind returns the value kind.
	Kind() Kind

	// String returns a short, quoted (human-readable) form of the value.
	// For numeric values, the result may be an approximation;
	// for String values the result may be a shortened string.
	// Use ExactString for a string representing a value exactly.
	String() string

	// ExactString returns an exact, quoted (human-readable) form of the value.
	// If the Value is of Kind String, use StringVal to obtain the unquoted string.
	ExactString() string

	// Prevent external implementations.
	// (implementsValue is unexported, so no type declared outside this
	// package can satisfy Value.)
	implementsValue()
}
63
64// ----------------------------------------------------------------------------
65// Implementations
66
// Maximum supported mantissa precision.
// The spec requires at least 256 bits; typical implementations use 512 bits.
// Every *big.Float created via newFloat uses this precision.
const prec = 512
70
71// TODO(gri) Consider storing "error" information in an unknownVal so clients
72// can provide better error messages. For instance, if a number is
73// too large (incl. infinity), that could be recorded in unknownVal.
74// See also #20583 and #42695 for use cases.
75
76// Representation of values:
77//
78// Values of Int and Float Kind have two different representations each: int64Val
79// and intVal, and ratVal and floatVal. When possible, the "smaller", respectively
80// more precise (for Floats) representation is chosen. However, once a Float value
81// is represented as a floatVal, any subsequent results remain floatVals (unless
82// explicitly converted); i.e., no attempt is made to convert a floatVal back into
83// a ratVal. The reasoning is that all representations but floatVal are mathematically
84// exact, but once that precision is lost (by moving to floatVal), moving back to
85// a different representation implies a precision that's not actually there.
86
type (
	// unknownVal is the representation of the Unknown value.
	unknownVal struct{}
	boolVal    bool
	stringVal  struct {
		// Lazy value: either a string (l,r==nil) or an addition (l,r!=nil).
		// mu guards s, l, and r: string() flattens an addition into s on
		// first use and then clears l and r.
		mu   sync.Mutex
		s    string
		l, r *stringVal
	}
	int64Val   int64                    // Int values representable as an int64
	intVal     struct{ val *big.Int }   // Int values not representable as an int64
	ratVal     struct{ val *big.Rat }   // Float values representable as a fraction
	floatVal   struct{ val *big.Float } // Float values not representable as a fraction
	// complexVal holds the real (re) and imaginary (im) parts; see Real and Imag.
	complexVal struct{ re, im Value }
)
102
103func (unknownVal) Kind() Kind { return Unknown }
104func (boolVal) Kind() Kind    { return Bool }
105func (*stringVal) Kind() Kind { return String }
106func (int64Val) Kind() Kind   { return Int }
107func (intVal) Kind() Kind     { return Int }
108func (ratVal) Kind() Kind     { return Float }
109func (floatVal) Kind() Kind   { return Float }
110func (complexVal) Kind() Kind { return Complex }
111
112func (unknownVal) String() string { return "unknown" }
113func (x boolVal) String() string  { return strconv.FormatBool(bool(x)) }
114
115// String returns a possibly shortened quoted form of the String value.
116func (x *stringVal) String() string {
117	const maxLen = 72 // a reasonable length
118	s := strconv.Quote(x.string())
119	if utf8.RuneCountInString(s) > maxLen {
120		// The string without the enclosing quotes is greater than maxLen-2 runes
121		// long. Remove the last 3 runes (including the closing '"') by keeping
122		// only the first maxLen-3 runes; then add "...".
123		i := 0
124		for n := 0; n < maxLen-3; n++ {
125			_, size := utf8.DecodeRuneInString(s[i:])
126			i += size
127		}
128		s = s[:i] + "..."
129	}
130	return s
131}
132
// string constructs and returns the actual string literal value.
// If x represents an addition, then it rewrites x to be a single
// string, to speed future calls. This lazy construction avoids
// building different string values for all subpieces of a large
// concatenation. See golang.org/issue/23348.
//
// x.mu is held across the whole rewrite, so concurrent callers see either
// the unflattened addition or the final string. appendReverse requires its
// receiver to be locked, which is satisfied here.
func (x *stringVal) string() string {
	x.mu.Lock()
	if x.l != nil {
		// appendReverse collects the pieces right-to-left; flip them back
		// into source order before joining.
		x.s = strings.Join(reverse(x.appendReverse(nil)), "")
		// x is now a plain string (l == nil), so later calls skip this branch.
		x.l = nil
		x.r = nil
	}
	s := x.s
	x.mu.Unlock()

	return s
}
150
// reverse reverses the order of the elements of x in place and returns x.
func reverse(x []string) []string {
	for lo, hi := 0, len(x)-1; lo < hi; lo, hi = lo+1, hi-1 {
		x[lo], x[hi] = x[hi], x[lo]
	}
	return x
}
159
// appendReverse appends to list all of x's subpieces, but in reverse,
// and returns the result. Appending the reversal allows processing
// the right side in a recursive call and the left side in a loop.
// Because a chain like a + b + c + d + e is actually represented
// as ((((a + b) + c) + d) + e), the left-side loop avoids deep recursion.
// x must be locked.
//
// Locking: x stays locked for the caller throughout. Each right operand is
// locked only around its recursive call. While walking down the left
// spine, the current node (unless it is x) is unlocked before the next
// left node is locked. The loop assumes l and r are both set or both nil,
// per the stringVal field comment.
func (x *stringVal) appendReverse(list []string) []string {
	y := x
	for y.r != nil {
		y.r.mu.Lock()
		list = y.r.appendReverse(list)
		y.r.mu.Unlock()

		l := y.l // read while y is still locked
		if y != x {
			y.mu.Unlock()
		}
		l.mu.Lock()
		y = l
	}
	// y is the leftmost leaf: a plain string.
	s := y.s
	if y != x {
		y.mu.Unlock()
	}
	return append(list, s)
}
186
187func (x int64Val) String() string { return strconv.FormatInt(int64(x), 10) }
188func (x intVal) String() string   { return x.val.String() }
189func (x ratVal) String() string   { return rtof(x).String() }
190
// String returns a decimal approximation of the Float value.
//
// Values within float64 range are formatted with %.6g (or %g when %.6g
// would hide a fractional part). Values outside float64 range are
// approximated as m*10**e by hand, avoiding precise but possibly slow
// big.Float decimal formatting.
func (x floatVal) String() string {
	f := x.val

	// Don't try to convert infinities (will not terminate).
	if f.IsInf() {
		return f.String()
	}

	// Use exact fmt formatting if in float64 range (common case):
	// proceed if f doesn't underflow to 0 or overflow to inf.
	// The condition parses as (f.Sign() == 0) == (x == 0): the float64 x
	// is zero exactly when f is zero. Note that x here shadows the receiver.
	if x, _ := f.Float64(); f.Sign() == 0 == (x == 0) && !math.IsInf(x, 0) {
		s := fmt.Sprintf("%.6g", x)
		if !f.IsInt() && strings.IndexByte(s, '.') < 0 {
			// f is not an integer, but its string representation
			// doesn't reflect that. Use more digits. See issue 56220.
			s = fmt.Sprintf("%g", x)
		}
		return s
	}

	// Out of float64 range. Do approximate manual to decimal
	// conversion to avoid precise but possibly slow Float
	// formatting.
	// f = mant * 2**exp
	var mant big.Float
	exp := f.MantExp(&mant) // 0.5 <= |mant| < 1.0

	// approximate float64 mantissa m and decimal exponent d
	// f ~ m * 10**d
	m, _ := mant.Float64()                     // 0.5 <= |m| < 1.0
	d := float64(exp) * (math.Ln2 / math.Ln10) // log_10(2)

	// adjust m for truncated (integer) decimal exponent e
	e := int64(d)
	m *= math.Pow(10, d-float64(e))

	// ensure 1 <= |m| < 10
	switch am := math.Abs(m); {
	case am < 1-0.5e-6:
		// The %.6g format below rounds m to 5 digits after the
		// decimal point. Make sure that m*10 < 10 even after
		// rounding up: m*10 + 0.5e-5 < 10 => m < 1 - 0.5e-6.
		m *= 10
		e--
	case am >= 10:
		m /= 10
		e++
	}

	return fmt.Sprintf("%.6ge%+d", m, e)
}
243
244func (x complexVal) String() string { return fmt.Sprintf("(%s + %si)", x.re, x.im) }
245
246func (x unknownVal) ExactString() string { return x.String() }
247func (x boolVal) ExactString() string    { return x.String() }
248func (x *stringVal) ExactString() string { return strconv.Quote(x.string()) }
249func (x int64Val) ExactString() string   { return x.String() }
250func (x intVal) ExactString() string     { return x.String() }
251
252func (x ratVal) ExactString() string {
253	r := x.val
254	if r.IsInt() {
255		return r.Num().String()
256	}
257	return r.String()
258}
259
260func (x floatVal) ExactString() string { return x.val.Text('p', 0) }
261
262func (x complexVal) ExactString() string {
263	return fmt.Sprintf("(%s + %si)", x.re.ExactString(), x.im.ExactString())
264}
265
266func (unknownVal) implementsValue() {}
267func (boolVal) implementsValue()    {}
268func (*stringVal) implementsValue() {}
269func (int64Val) implementsValue()   {}
270func (ratVal) implementsValue()     {}
271func (intVal) implementsValue()     {}
272func (floatVal) implementsValue()   {}
273func (complexVal) implementsValue() {}
274
275func newInt() *big.Int     { return new(big.Int) }
276func newRat() *big.Rat     { return new(big.Rat) }
277func newFloat() *big.Float { return new(big.Float).SetPrec(prec) }
278
279func i64toi(x int64Val) intVal   { return intVal{newInt().SetInt64(int64(x))} }
280func i64tor(x int64Val) ratVal   { return ratVal{newRat().SetInt64(int64(x))} }
281func i64tof(x int64Val) floatVal { return floatVal{newFloat().SetInt64(int64(x))} }
282func itor(x intVal) ratVal       { return ratVal{newRat().SetInt(x.val)} }
283func itof(x intVal) floatVal     { return floatVal{newFloat().SetInt(x.val)} }
284func rtof(x ratVal) floatVal     { return floatVal{newFloat().SetRat(x.val)} }
285func vtoc(x Value) complexVal    { return complexVal{x, int64Val(0)} }
286
287func makeInt(x *big.Int) Value {
288	if x.IsInt64() {
289		return int64Val(x.Int64())
290	}
291	return intVal{x}
292}
293
294func makeRat(x *big.Rat) Value {
295	a := x.Num()
296	b := x.Denom()
297	if smallInt(a) && smallInt(b) {
298		// ok to remain fraction
299		return ratVal{x}
300	}
301	// components too large => switch to float
302	return floatVal{newFloat().SetRat(x)}
303}
304
305var floatVal0 = floatVal{newFloat()}
306
307func makeFloat(x *big.Float) Value {
308	// convert -0
309	if x.Sign() == 0 {
310		return floatVal0
311	}
312	if x.IsInf() {
313		return unknownVal{}
314	}
315	// No attempt is made to "go back" to ratVal, even if possible,
316	// to avoid providing the illusion of a mathematically exact
317	// representation.
318	return floatVal{x}
319}
320
321func makeComplex(re, im Value) Value {
322	if re.Kind() == Unknown || im.Kind() == Unknown {
323		return unknownVal{}
324	}
325	return complexVal{re, im}
326}
327
328func makeFloatFromLiteral(lit string) Value {
329	if f, ok := newFloat().SetString(lit); ok {
330		if smallFloat(f) {
331			// ok to use rationals
332			if f.Sign() == 0 {
333				// Issue 20228: If the float underflowed to zero, parse just "0".
334				// Otherwise, lit might contain a value with a large negative exponent,
335				// such as -6e-1886451601. As a float, that will underflow to 0,
336				// but it'll take forever to parse as a Rat.
337				lit = "0"
338			}
339			if r, ok := newRat().SetString(lit); ok {
340				return ratVal{r}
341			}
342		}
343		// otherwise use floats
344		return makeFloat(f)
345	}
346	return nil
347}
348
// Permit fractions with component sizes up to maxExp bits
// before switching to using floating-point numbers.
const maxExp = 4 << 10 // 4096

// smallInt reports whether x would lead to "reasonably"-sized fraction
// if converted to a *big.Rat, i.e. whether |x| needs fewer than maxExp bits.
func smallInt(x *big.Int) bool {
	n := x.BitLen()
	return n < maxExp
}

// smallFloat64 reports whether x would lead to "reasonably"-sized fraction
// if converted to a *big.Rat. Infinities never qualify.
func smallFloat64(x float64) bool {
	if !math.IsInf(x, 0) {
		_, exp := math.Frexp(x)
		return exp > -maxExp && exp < maxExp
	}
	return false
}

// smallFloat is the *big.Float counterpart of smallFloat64.
func smallFloat(x *big.Float) bool {
	if !x.IsInf() {
		exp := x.MantExp(nil)
		return exp > -maxExp && exp < maxExp
	}
	return false
}
378
379// ----------------------------------------------------------------------------
380// Factories
381
382// MakeUnknown returns the [Unknown] value.
383func MakeUnknown() Value { return unknownVal{} }
384
385// MakeBool returns the [Bool] value for b.
386func MakeBool(b bool) Value { return boolVal(b) }
387
388// MakeString returns the [String] value for s.
389func MakeString(s string) Value {
390	if s == "" {
391		return &emptyString // common case
392	}
393	return &stringVal{s: s}
394}
395
396var emptyString stringVal
397
398// MakeInt64 returns the [Int] value for x.
399func MakeInt64(x int64) Value { return int64Val(x) }
400
401// MakeUint64 returns the [Int] value for x.
402func MakeUint64(x uint64) Value {
403	if x < 1<<63 {
404		return int64Val(int64(x))
405	}
406	return intVal{newInt().SetUint64(x)}
407}
408
409// MakeFloat64 returns the [Float] value for x.
410// If x is -0.0, the result is 0.0.
411// If x is not finite, the result is an [Unknown].
412func MakeFloat64(x float64) Value {
413	if math.IsInf(x, 0) || math.IsNaN(x) {
414		return unknownVal{}
415	}
416	if smallFloat64(x) {
417		return ratVal{newRat().SetFloat64(x + 0)} // convert -0 to 0
418	}
419	return floatVal{newFloat().SetFloat64(x + 0)}
420}
421
422// MakeFromLiteral returns the corresponding integer, floating-point,
423// imaginary, character, or string value for a Go literal string. The
424// tok value must be one of [token.INT], [token.FLOAT], [token.IMAG],
425// [token.CHAR], or [token.STRING]. The final argument must be zero.
426// If the literal string syntax is invalid, the result is an [Unknown].
427func MakeFromLiteral(lit string, tok token.Token, zero uint) Value {
428	if zero != 0 {
429		panic("MakeFromLiteral called with non-zero last argument")
430	}
431
432	switch tok {
433	case token.INT:
434		if x, err := strconv.ParseInt(lit, 0, 64); err == nil {
435			return int64Val(x)
436		}
437		if x, ok := newInt().SetString(lit, 0); ok {
438			return intVal{x}
439		}
440
441	case token.FLOAT:
442		if x := makeFloatFromLiteral(lit); x != nil {
443			return x
444		}
445
446	case token.IMAG:
447		if n := len(lit); n > 0 && lit[n-1] == 'i' {
448			if im := makeFloatFromLiteral(lit[:n-1]); im != nil {
449				return makeComplex(int64Val(0), im)
450			}
451		}
452
453	case token.CHAR:
454		if n := len(lit); n >= 2 {
455			if code, _, _, err := strconv.UnquoteChar(lit[1:n-1], '\''); err == nil {
456				return MakeInt64(int64(code))
457			}
458		}
459
460	case token.STRING:
461		if s, err := strconv.Unquote(lit); err == nil {
462			return MakeString(s)
463		}
464
465	default:
466		panic(fmt.Sprintf("%v is not a valid token", tok))
467	}
468
469	return unknownVal{}
470}
471
472// ----------------------------------------------------------------------------
473// Accessors
474//
475// For unknown arguments the result is the zero value for the respective
476// accessor type, except for Sign, where the result is 1.
477
478// BoolVal returns the Go boolean value of x, which must be a [Bool] or an [Unknown].
479// If x is [Unknown], the result is false.
480func BoolVal(x Value) bool {
481	switch x := x.(type) {
482	case boolVal:
483		return bool(x)
484	case unknownVal:
485		return false
486	default:
487		panic(fmt.Sprintf("%v not a Bool", x))
488	}
489}
490
491// StringVal returns the Go string value of x, which must be a [String] or an [Unknown].
492// If x is [Unknown], the result is "".
493func StringVal(x Value) string {
494	switch x := x.(type) {
495	case *stringVal:
496		return x.string()
497	case unknownVal:
498		return ""
499	default:
500		panic(fmt.Sprintf("%v not a String", x))
501	}
502}
503
504// Int64Val returns the Go int64 value of x and whether the result is exact;
505// x must be an [Int] or an [Unknown]. If the result is not exact, its value is undefined.
506// If x is [Unknown], the result is (0, false).
507func Int64Val(x Value) (int64, bool) {
508	switch x := x.(type) {
509	case int64Val:
510		return int64(x), true
511	case intVal:
512		return x.val.Int64(), false // not an int64Val and thus not exact
513	case unknownVal:
514		return 0, false
515	default:
516		panic(fmt.Sprintf("%v not an Int", x))
517	}
518}
519
520// Uint64Val returns the Go uint64 value of x and whether the result is exact;
521// x must be an [Int] or an [Unknown]. If the result is not exact, its value is undefined.
522// If x is [Unknown], the result is (0, false).
523func Uint64Val(x Value) (uint64, bool) {
524	switch x := x.(type) {
525	case int64Val:
526		return uint64(x), x >= 0
527	case intVal:
528		return x.val.Uint64(), x.val.IsUint64()
529	case unknownVal:
530		return 0, false
531	default:
532		panic(fmt.Sprintf("%v not an Int", x))
533	}
534}
535
536// Float32Val is like [Float64Val] but for float32 instead of float64.
537func Float32Val(x Value) (float32, bool) {
538	switch x := x.(type) {
539	case int64Val:
540		f := float32(x)
541		return f, int64Val(f) == x
542	case intVal:
543		f, acc := newFloat().SetInt(x.val).Float32()
544		return f, acc == big.Exact
545	case ratVal:
546		return x.val.Float32()
547	case floatVal:
548		f, acc := x.val.Float32()
549		return f, acc == big.Exact
550	case unknownVal:
551		return 0, false
552	default:
553		panic(fmt.Sprintf("%v not a Float", x))
554	}
555}
556
557// Float64Val returns the nearest Go float64 value of x and whether the result is exact;
558// x must be numeric or an [Unknown], but not [Complex]. For values too small (too close to 0)
559// to represent as float64, [Float64Val] silently underflows to 0. The result sign always
560// matches the sign of x, even for 0.
561// If x is [Unknown], the result is (0, false).
562func Float64Val(x Value) (float64, bool) {
563	switch x := x.(type) {
564	case int64Val:
565		f := float64(int64(x))
566		return f, int64Val(f) == x
567	case intVal:
568		f, acc := newFloat().SetInt(x.val).Float64()
569		return f, acc == big.Exact
570	case ratVal:
571		return x.val.Float64()
572	case floatVal:
573		f, acc := x.val.Float64()
574		return f, acc == big.Exact
575	case unknownVal:
576		return 0, false
577	default:
578		panic(fmt.Sprintf("%v not a Float", x))
579	}
580}
581
582// Val returns the underlying value for a given constant. Since it returns an
583// interface, it is up to the caller to type assert the result to the expected
584// type. The possible dynamic return types are:
585//
586//	x Kind             type of result
587//	-----------------------------------------
588//	Bool               bool
589//	String             string
590//	Int                int64 or *big.Int
591//	Float              *big.Float or *big.Rat
592//	everything else    nil
593func Val(x Value) any {
594	switch x := x.(type) {
595	case boolVal:
596		return bool(x)
597	case *stringVal:
598		return x.string()
599	case int64Val:
600		return int64(x)
601	case intVal:
602		return x.val
603	case ratVal:
604		return x.val
605	case floatVal:
606		return x.val
607	default:
608		return nil
609	}
610}
611
612// Make returns the [Value] for x.
613//
614//	type of x        result Kind
615//	----------------------------
616//	bool             Bool
617//	string           String
618//	int64            Int
619//	*big.Int         Int
620//	*big.Float       Float
621//	*big.Rat         Float
622//	anything else    Unknown
623func Make(x any) Value {
624	switch x := x.(type) {
625	case bool:
626		return boolVal(x)
627	case string:
628		return &stringVal{s: x}
629	case int64:
630		return int64Val(x)
631	case *big.Int:
632		return makeInt(x)
633	case *big.Rat:
634		return makeRat(x)
635	case *big.Float:
636		return makeFloat(x)
637	default:
638		return unknownVal{}
639	}
640}
641
642// BitLen returns the number of bits required to represent
643// the absolute value x in binary representation; x must be an [Int] or an [Unknown].
644// If x is [Unknown], the result is 0.
645func BitLen(x Value) int {
646	switch x := x.(type) {
647	case int64Val:
648		u := uint64(x)
649		if x < 0 {
650			u = uint64(-x)
651		}
652		return 64 - bits.LeadingZeros64(u)
653	case intVal:
654		return x.val.BitLen()
655	case unknownVal:
656		return 0
657	default:
658		panic(fmt.Sprintf("%v not an Int", x))
659	}
660}
661
662// Sign returns -1, 0, or 1 depending on whether x < 0, x == 0, or x > 0;
663// x must be numeric or [Unknown]. For complex values x, the sign is 0 if x == 0,
664// otherwise it is != 0. If x is [Unknown], the result is 1.
665func Sign(x Value) int {
666	switch x := x.(type) {
667	case int64Val:
668		switch {
669		case x < 0:
670			return -1
671		case x > 0:
672			return 1
673		}
674		return 0
675	case intVal:
676		return x.val.Sign()
677	case ratVal:
678		return x.val.Sign()
679	case floatVal:
680		return x.val.Sign()
681	case complexVal:
682		return Sign(x.re) | Sign(x.im)
683	case unknownVal:
684		return 1 // avoid spurious division by zero errors
685	default:
686		panic(fmt.Sprintf("%v not numeric", x))
687	}
688}
689
690// ----------------------------------------------------------------------------
691// Support for assembling/disassembling numeric values
692
const (
	// Compute the size of a Word in bytes.
	// _m is a Word with all bits set. Each term _m>>k&1 is 1 exactly when
	// a Word is wider than k bits, so _log is log2 of the Word size in
	// bytes (2 for 32-bit Words, 3 for 64-bit Words).
	_m       = ^big.Word(0)
	_log     = _m>>8&1 + _m>>16&1 + _m>>32&1
	wordSize = 1 << _log
)
699
700// Bytes returns the bytes for the absolute value of x in little-
701// endian binary representation; x must be an [Int].
702func Bytes(x Value) []byte {
703	var t intVal
704	switch x := x.(type) {
705	case int64Val:
706		t = i64toi(x)
707	case intVal:
708		t = x
709	default:
710		panic(fmt.Sprintf("%v not an Int", x))
711	}
712
713	words := t.val.Bits()
714	bytes := make([]byte, len(words)*wordSize)
715
716	i := 0
717	for _, w := range words {
718		for j := 0; j < wordSize; j++ {
719			bytes[i] = byte(w)
720			w >>= 8
721			i++
722		}
723	}
724	// remove leading 0's
725	for i > 0 && bytes[i-1] == 0 {
726		i--
727	}
728
729	return bytes[:i]
730}
731
732// MakeFromBytes returns the [Int] value given the bytes of its little-endian
733// binary representation. An empty byte slice argument represents 0.
734func MakeFromBytes(bytes []byte) Value {
735	words := make([]big.Word, (len(bytes)+(wordSize-1))/wordSize)
736
737	i := 0
738	var w big.Word
739	var s uint
740	for _, b := range bytes {
741		w |= big.Word(b) << s
742		if s += 8; s == wordSize*8 {
743			words[i] = w
744			i++
745			w = 0
746			s = 0
747		}
748	}
749	// store last word
750	if i < len(words) {
751		words[i] = w
752		i++
753	}
754	// remove leading 0's
755	for i > 0 && words[i-1] == 0 {
756		i--
757	}
758
759	return makeInt(newInt().SetBits(words[:i]))
760}
761
762// Num returns the numerator of x; x must be [Int], [Float], or [Unknown].
763// If x is [Unknown], or if it is too large or small to represent as a
764// fraction, the result is [Unknown]. Otherwise the result is an [Int]
765// with the same sign as x.
766func Num(x Value) Value {
767	switch x := x.(type) {
768	case int64Val, intVal:
769		return x
770	case ratVal:
771		return makeInt(x.val.Num())
772	case floatVal:
773		if smallFloat(x.val) {
774			r, _ := x.val.Rat(nil)
775			return makeInt(r.Num())
776		}
777	case unknownVal:
778		break
779	default:
780		panic(fmt.Sprintf("%v not Int or Float", x))
781	}
782	return unknownVal{}
783}
784
785// Denom returns the denominator of x; x must be [Int], [Float], or [Unknown].
786// If x is [Unknown], or if it is too large or small to represent as a
787// fraction, the result is [Unknown]. Otherwise the result is an [Int] >= 1.
788func Denom(x Value) Value {
789	switch x := x.(type) {
790	case int64Val, intVal:
791		return int64Val(1)
792	case ratVal:
793		return makeInt(x.val.Denom())
794	case floatVal:
795		if smallFloat(x.val) {
796			r, _ := x.val.Rat(nil)
797			return makeInt(r.Denom())
798		}
799	case unknownVal:
800		break
801	default:
802		panic(fmt.Sprintf("%v not Int or Float", x))
803	}
804	return unknownVal{}
805}
806
807// MakeImag returns the [Complex] value x*i;
808// x must be [Int], [Float], or [Unknown].
809// If x is [Unknown], the result is [Unknown].
810func MakeImag(x Value) Value {
811	switch x.(type) {
812	case unknownVal:
813		return x
814	case int64Val, intVal, ratVal, floatVal:
815		return makeComplex(int64Val(0), x)
816	default:
817		panic(fmt.Sprintf("%v not Int or Float", x))
818	}
819}
820
821// Real returns the real part of x, which must be a numeric or unknown value.
822// If x is [Unknown], the result is [Unknown].
823func Real(x Value) Value {
824	switch x := x.(type) {
825	case unknownVal, int64Val, intVal, ratVal, floatVal:
826		return x
827	case complexVal:
828		return x.re
829	default:
830		panic(fmt.Sprintf("%v not numeric", x))
831	}
832}
833
834// Imag returns the imaginary part of x, which must be a numeric or unknown value.
835// If x is [Unknown], the result is [Unknown].
836func Imag(x Value) Value {
837	switch x := x.(type) {
838	case unknownVal:
839		return x
840	case int64Val, intVal, ratVal, floatVal:
841		return int64Val(0)
842	case complexVal:
843		return x.im
844	default:
845		panic(fmt.Sprintf("%v not numeric", x))
846	}
847}
848
849// ----------------------------------------------------------------------------
850// Numeric conversions
851
// ToInt converts x to an [Int] value if x is representable as an [Int].
// Otherwise it returns an [Unknown].
//
// A Float that is not exactly integral still converts if re-rounding it
// to slightly lower precision (prec-4 bits, toward zero or away from
// zero) yields an integer; this tolerates tiny errors from earlier
// floating-point computations. A Complex converts only if ToFloat
// accepts it (zero imaginary part).
func ToInt(x Value) Value {
	switch x := x.(type) {
	case int64Val, intVal:
		return x

	case ratVal:
		if x.val.IsInt() {
			return makeInt(x.val.Num())
		}

	case floatVal:
		// avoid creation of huge integers
		// (Existing tests require permitting exponents of at least 1024;
		// allow any value that would also be permissible as a fraction.)
		if smallFloat(x.val) {
			i := newInt()
			if _, acc := x.val.Int(i); acc == big.Exact {
				return makeInt(i)
			}

			// If we can get an integer by rounding up or down,
			// assume x is not an integer because of rounding
			// errors in prior computations.

			const delta = 4 // a small number of bits > 0
			var t big.Float
			t.SetPrec(prec - delta)

			// try rounding down a little
			t.SetMode(big.ToZero)
			t.Set(x.val)
			if _, acc := t.Int(i); acc == big.Exact {
				return makeInt(i)
			}

			// try rounding up a little
			t.SetMode(big.AwayFromZero)
			t.Set(x.val)
			if _, acc := t.Int(i); acc == big.Exact {
				return makeInt(i)
			}
		}

	case complexVal:
		// ToFloat yields a Float only when the imaginary part is zero.
		if re := ToFloat(x); re.Kind() == Float {
			return ToInt(re)
		}
	}

	// Not representable as an Int (includes Unknown, Bool and String).
	return unknownVal{}
}
905
906// ToFloat converts x to a [Float] value if x is representable as a [Float].
907// Otherwise it returns an [Unknown].
908func ToFloat(x Value) Value {
909	switch x := x.(type) {
910	case int64Val:
911		return i64tor(x) // x is always a small int
912	case intVal:
913		if smallInt(x.val) {
914			return itor(x)
915		}
916		return itof(x)
917	case ratVal, floatVal:
918		return x
919	case complexVal:
920		if Sign(x.im) == 0 {
921			return ToFloat(x.re)
922		}
923	}
924	return unknownVal{}
925}
926
927// ToComplex converts x to a [Complex] value if x is representable as a [Complex].
928// Otherwise it returns an [Unknown].
929func ToComplex(x Value) Value {
930	switch x := x.(type) {
931	case int64Val, intVal, ratVal, floatVal:
932		return vtoc(x)
933	case complexVal:
934		return x
935	}
936	return unknownVal{}
937}
938
939// ----------------------------------------------------------------------------
940// Operations
941
// is32bit reports whether x can be represented using 32 bits, i.e.
// whether x survives a round trip through int32.
func is32bit(x int64) bool {
	return int64(int32(x)) == x
}

// is63bit reports whether x can be represented using 63 bits, i.e.
// whether -2**62 <= x < 2**62.
func is63bit(x int64) bool {
	const limit = 1 << 62
	return -limit <= x && x < limit
}
953
// UnaryOp returns the result of the unary expression op y.
// The operation must be defined for the operand.
// If prec > 0 it specifies the ^ (xor) result size in bits.
// If y is [Unknown], the result is [Unknown].
//
// Handled combinations: token.ADD and token.SUB on numeric values,
// token.XOR on [Int] values, and token.NOT on [Bool] values; every other
// combination reaches the Error label and panics. Within this function
// the prec parameter shadows the package-level mantissa precision constant.
func UnaryOp(op token.Token, y Value, prec uint) Value {
	switch op {
	case token.ADD:
		switch y.(type) {
		case unknownVal, int64Val, intVal, ratVal, floatVal, complexVal:
			return y
		}

	case token.SUB:
		switch y := y.(type) {
		case unknownVal:
			return y
		case int64Val:
			// -y == y only for 0 and for the most negative int64 (where
			// negation wraps); both take the big.Int path below, which
			// produces the correct result.
			if z := -y; z != y {
				return z // no overflow
			}
			return makeInt(newInt().Neg(big.NewInt(int64(y))))
		case intVal:
			return makeInt(newInt().Neg(y.val))
		case ratVal:
			return makeRat(newRat().Neg(y.val))
		case floatVal:
			return makeFloat(newFloat().Neg(y.val))
		case complexVal:
			re := UnaryOp(token.SUB, y.re, 0)
			im := UnaryOp(token.SUB, y.im, 0)
			return makeComplex(re, im)
		}

	case token.XOR:
		// z = ^y (bitwise complement).
		z := newInt()
		switch y := y.(type) {
		case unknownVal:
			return y
		case int64Val:
			z.Not(big.NewInt(int64(y)))
		case intVal:
			z.Not(y.val)
		default:
			goto Error
		}
		// For unsigned types, the result will be negative and
		// thus "too large": We must limit the result precision
		// to the type's precision.
		if prec > 0 {
			z.AndNot(z, newInt().Lsh(big.NewInt(-1), prec)) // z &^= (-1)<<prec
		}
		return makeInt(z)

	case token.NOT:
		switch y := y.(type) {
		case unknownVal:
			return y
		case boolVal:
			return !y
		}
	}

	// Reached for any operator/operand combination not handled above.
Error:
	panic(fmt.Sprintf("invalid unary operation %s%v", op, y))
}
1019
1020func ord(x Value) int {
1021	switch x.(type) {
1022	default:
1023		// force invalid value into "x position" in match
1024		// (don't panic here so that callers can provide a better error message)
1025		return -1
1026	case unknownVal:
1027		return 0
1028	case boolVal, *stringVal:
1029		return 1
1030	case int64Val:
1031		return 2
1032	case intVal:
1033		return 3
1034	case ratVal:
1035		return 4
1036	case floatVal:
1037		return 5
1038	case complexVal:
1039		return 6
1040	}
1041}
1042
1043// match returns the matching representation (same type) with the
1044// smallest complexity for two values x and y. If one of them is
1045// numeric, both of them must be numeric. If one of them is Unknown
1046// or invalid (say, nil) both results are that value.
1047func match(x, y Value) (_, _ Value) {
1048	switch ox, oy := ord(x), ord(y); {
1049	case ox < oy:
1050		x, y = match0(x, y)
1051	case ox > oy:
1052		y, x = match0(y, x)
1053	}
1054	return x, y
1055}
1056
1057// match0 must only be called by match.
1058// Invariant: ord(x) < ord(y)
1059func match0(x, y Value) (_, _ Value) {
1060	// Prefer to return the original x and y arguments when possible,
1061	// to avoid unnecessary heap allocations.
1062
1063	switch y.(type) {
1064	case intVal:
1065		switch x1 := x.(type) {
1066		case int64Val:
1067			return i64toi(x1), y
1068		}
1069	case ratVal:
1070		switch x1 := x.(type) {
1071		case int64Val:
1072			return i64tor(x1), y
1073		case intVal:
1074			return itor(x1), y
1075		}
1076	case floatVal:
1077		switch x1 := x.(type) {
1078		case int64Val:
1079			return i64tof(x1), y
1080		case intVal:
1081			return itof(x1), y
1082		case ratVal:
1083			return rtof(x1), y
1084		}
1085	case complexVal:
1086		return vtoc(x), y
1087	}
1088
1089	// force unknown and invalid values into "x position" in callers of match
1090	// (don't panic here so that callers can provide a better error message)
1091	return x, x
1092}
1093
1094// BinaryOp returns the result of the binary expression x op y.
1095// The operation must be defined for the operands. If one of the
1096// operands is [Unknown], the result is [Unknown].
1097// BinaryOp doesn't handle comparisons or shifts; use [Compare]
1098// or [Shift] instead.
1099//
1100// To force integer division of [Int] operands, use op == [token.QUO_ASSIGN]
1101// instead of [token.QUO]; the result is guaranteed to be [Int] in this case.
1102// Division by zero leads to a run-time panic.
1103func BinaryOp(x_ Value, op token.Token, y_ Value) Value {
1104	x, y := match(x_, y_)
1105
1106	switch x := x.(type) {
1107	case unknownVal:
1108		return x
1109
1110	case boolVal:
1111		y := y.(boolVal)
1112		switch op {
1113		case token.LAND:
1114			return x && y
1115		case token.LOR:
1116			return x || y
1117		}
1118
1119	case int64Val:
1120		a := int64(x)
1121		b := int64(y.(int64Val))
1122		var c int64
1123		switch op {
1124		case token.ADD:
1125			if !is63bit(a) || !is63bit(b) {
1126				return makeInt(newInt().Add(big.NewInt(a), big.NewInt(b)))
1127			}
1128			c = a + b
1129		case token.SUB:
1130			if !is63bit(a) || !is63bit(b) {
1131				return makeInt(newInt().Sub(big.NewInt(a), big.NewInt(b)))
1132			}
1133			c = a - b
1134		case token.MUL:
1135			if !is32bit(a) || !is32bit(b) {
1136				return makeInt(newInt().Mul(big.NewInt(a), big.NewInt(b)))
1137			}
1138			c = a * b
1139		case token.QUO:
1140			return makeRat(big.NewRat(a, b))
1141		case token.QUO_ASSIGN: // force integer division
1142			c = a / b
1143		case token.REM:
1144			c = a % b
1145		case token.AND:
1146			c = a & b
1147		case token.OR:
1148			c = a | b
1149		case token.XOR:
1150			c = a ^ b
1151		case token.AND_NOT:
1152			c = a &^ b
1153		default:
1154			goto Error
1155		}
1156		return int64Val(c)
1157
1158	case intVal:
1159		a := x.val
1160		b := y.(intVal).val
1161		c := newInt()
1162		switch op {
1163		case token.ADD:
1164			c.Add(a, b)
1165		case token.SUB:
1166			c.Sub(a, b)
1167		case token.MUL:
1168			c.Mul(a, b)
1169		case token.QUO:
1170			return makeRat(newRat().SetFrac(a, b))
1171		case token.QUO_ASSIGN: // force integer division
1172			c.Quo(a, b)
1173		case token.REM:
1174			c.Rem(a, b)
1175		case token.AND:
1176			c.And(a, b)
1177		case token.OR:
1178			c.Or(a, b)
1179		case token.XOR:
1180			c.Xor(a, b)
1181		case token.AND_NOT:
1182			c.AndNot(a, b)
1183		default:
1184			goto Error
1185		}
1186		return makeInt(c)
1187
1188	case ratVal:
1189		a := x.val
1190		b := y.(ratVal).val
1191		c := newRat()
1192		switch op {
1193		case token.ADD:
1194			c.Add(a, b)
1195		case token.SUB:
1196			c.Sub(a, b)
1197		case token.MUL:
1198			c.Mul(a, b)
1199		case token.QUO:
1200			c.Quo(a, b)
1201		default:
1202			goto Error
1203		}
1204		return makeRat(c)
1205
1206	case floatVal:
1207		a := x.val
1208		b := y.(floatVal).val
1209		c := newFloat()
1210		switch op {
1211		case token.ADD:
1212			c.Add(a, b)
1213		case token.SUB:
1214			c.Sub(a, b)
1215		case token.MUL:
1216			c.Mul(a, b)
1217		case token.QUO:
1218			c.Quo(a, b)
1219		default:
1220			goto Error
1221		}
1222		return makeFloat(c)
1223
1224	case complexVal:
1225		y := y.(complexVal)
1226		a, b := x.re, x.im
1227		c, d := y.re, y.im
1228		var re, im Value
1229		switch op {
1230		case token.ADD:
1231			// (a+c) + i(b+d)
1232			re = add(a, c)
1233			im = add(b, d)
1234		case token.SUB:
1235			// (a-c) + i(b-d)
1236			re = sub(a, c)
1237			im = sub(b, d)
1238		case token.MUL:
1239			// (ac-bd) + i(bc+ad)
1240			ac := mul(a, c)
1241			bd := mul(b, d)
1242			bc := mul(b, c)
1243			ad := mul(a, d)
1244			re = sub(ac, bd)
1245			im = add(bc, ad)
1246		case token.QUO:
1247			// (ac+bd)/s + i(bc-ad)/s, with s = cc + dd
1248			ac := mul(a, c)
1249			bd := mul(b, d)
1250			bc := mul(b, c)
1251			ad := mul(a, d)
1252			cc := mul(c, c)
1253			dd := mul(d, d)
1254			s := add(cc, dd)
1255			re = add(ac, bd)
1256			re = quo(re, s)
1257			im = sub(bc, ad)
1258			im = quo(im, s)
1259		default:
1260			goto Error
1261		}
1262		return makeComplex(re, im)
1263
1264	case *stringVal:
1265		if op == token.ADD {
1266			return &stringVal{l: x, r: y.(*stringVal)}
1267		}
1268	}
1269
1270Error:
1271	panic(fmt.Sprintf("invalid binary operation %v %s %v", x_, op, y_))
1272}
1273
1274func add(x, y Value) Value { return BinaryOp(x, token.ADD, y) }
1275func sub(x, y Value) Value { return BinaryOp(x, token.SUB, y) }
1276func mul(x, y Value) Value { return BinaryOp(x, token.MUL, y) }
1277func quo(x, y Value) Value { return BinaryOp(x, token.QUO, y) }
1278
1279// Shift returns the result of the shift expression x op s
1280// with op == [token.SHL] or [token.SHR] (<< or >>). x must be
1281// an [Int] or an [Unknown]. If x is [Unknown], the result is x.
1282func Shift(x Value, op token.Token, s uint) Value {
1283	switch x := x.(type) {
1284	case unknownVal:
1285		return x
1286
1287	case int64Val:
1288		if s == 0 {
1289			return x
1290		}
1291		switch op {
1292		case token.SHL:
1293			z := i64toi(x).val
1294			return makeInt(z.Lsh(z, s))
1295		case token.SHR:
1296			return x >> s
1297		}
1298
1299	case intVal:
1300		if s == 0 {
1301			return x
1302		}
1303		z := newInt()
1304		switch op {
1305		case token.SHL:
1306			return makeInt(z.Lsh(x.val, s))
1307		case token.SHR:
1308			return makeInt(z.Rsh(x.val, s))
1309		}
1310	}
1311
1312	panic(fmt.Sprintf("invalid shift %v %s %d", x, op, s))
1313}
1314
// cmpZero reports whether x op 0 holds, where x is the result of a
// three-way comparison (negative, zero, or positive) and op is one of
// EQL, NEQ, LSS, LEQ, GTR or GEQ. Any other op causes a panic.
func cmpZero(x int, op token.Token) bool {
	negative, zero := x < 0, x == 0
	switch op {
	case token.EQL:
		return zero
	case token.NEQ:
		return !zero
	case token.LSS:
		return negative
	case token.LEQ:
		return negative || zero
	case token.GTR:
		return !negative && !zero
	case token.GEQ:
		return !negative
	}
	panic(fmt.Sprintf("invalid comparison %v %s 0", x, op))
}
1332
1333// Compare returns the result of the comparison x op y.
1334// The comparison must be defined for the operands.
1335// If one of the operands is [Unknown], the result is
1336// false.
1337func Compare(x_ Value, op token.Token, y_ Value) bool {
1338	x, y := match(x_, y_)
1339
1340	switch x := x.(type) {
1341	case unknownVal:
1342		return false
1343
1344	case boolVal:
1345		y := y.(boolVal)
1346		switch op {
1347		case token.EQL:
1348			return x == y
1349		case token.NEQ:
1350			return x != y
1351		}
1352
1353	case int64Val:
1354		y := y.(int64Val)
1355		switch op {
1356		case token.EQL:
1357			return x == y
1358		case token.NEQ:
1359			return x != y
1360		case token.LSS:
1361			return x < y
1362		case token.LEQ:
1363			return x <= y
1364		case token.GTR:
1365			return x > y
1366		case token.GEQ:
1367			return x >= y
1368		}
1369
1370	case intVal:
1371		return cmpZero(x.val.Cmp(y.(intVal).val), op)
1372
1373	case ratVal:
1374		return cmpZero(x.val.Cmp(y.(ratVal).val), op)
1375
1376	case floatVal:
1377		return cmpZero(x.val.Cmp(y.(floatVal).val), op)
1378
1379	case complexVal:
1380		y := y.(complexVal)
1381		re := Compare(x.re, token.EQL, y.re)
1382		im := Compare(x.im, token.EQL, y.im)
1383		switch op {
1384		case token.EQL:
1385			return re && im
1386		case token.NEQ:
1387			return !re || !im
1388		}
1389
1390	case *stringVal:
1391		xs := x.string()
1392		ys := y.(*stringVal).string()
1393		switch op {
1394		case token.EQL:
1395			return xs == ys
1396		case token.NEQ:
1397			return xs != ys
1398		case token.LSS:
1399			return xs < ys
1400		case token.LEQ:
1401			return xs <= ys
1402		case token.GTR:
1403			return xs > ys
1404		case token.GEQ:
1405			return xs >= ys
1406		}
1407	}
1408
1409	panic(fmt.Sprintf("invalid comparison %v %s %v", x_, op, y_))
1410}
1411