1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package typeparams
6
7import (
8	"errors"
9	"fmt"
10	"go/types"
11	"os"
12	"strings"
13)
14
15//go:generate go run copytermlist.go
16
// debug enables tracing of term set computations to standard error.
const debug = false

// ErrEmptyTypeSet is the error returned by StructuralTerms, InterfaceTermSet
// and UnionTermSet when the computed term set is empty, i.e. no type can
// satisfy the structural restrictions. Compare against it with errors.Is.
var ErrEmptyTypeSet = errors.New("empty type set")
20
21// StructuralTerms returns a slice of terms representing the normalized
22// structural type restrictions of a type parameter, if any.
23//
24// Structural type restrictions of a type parameter are created via
25// non-interface types embedded in its constraint interface (directly, or via a
26// chain of interface embeddings). For example, in the declaration
27//
28//	type T[P interface{~int; m()}] int
29//
30// the structural restriction of the type parameter P is ~int.
31//
32// With interface embedding and unions, the specification of structural type
33// restrictions may be arbitrarily complex. For example, consider the
34// following:
35//
36//	type A interface{ ~string|~[]byte }
37//
38//	type B interface{ int|string }
39//
40//	type C interface { ~string|~int }
41//
42//	type T[P interface{ A|B; C }] int
43//
44// In this example, the structural type restriction of P is ~string|int: A|B
45// expands to ~string|~[]byte|int|string, which reduces to ~string|~[]byte|int,
46// which when intersected with C (~string|~int) yields ~string|int.
47//
48// StructuralTerms computes these expansions and reductions, producing a
49// "normalized" form of the embeddings. A structural restriction is normalized
50// if it is a single union containing no interface terms, and is minimal in the
51// sense that removing any term changes the set of types satisfying the
52// constraint. It is left as a proof for the reader that, modulo sorting, there
53// is exactly one such normalized form.
54//
55// Because the minimal representation always takes this form, StructuralTerms
56// returns a slice of tilde terms corresponding to the terms of the union in
57// the normalized structural restriction. An error is returned if the
58// constraint interface is invalid, exceeds complexity bounds, or has an empty
59// type set. In the latter case, StructuralTerms returns ErrEmptyTypeSet.
60//
61// StructuralTerms makes no guarantees about the order of terms, except that it
62// is deterministic.
63func StructuralTerms(tparam *types.TypeParam) ([]*types.Term, error) {
64	constraint := tparam.Constraint()
65	if constraint == nil {
66		return nil, fmt.Errorf("%s has nil constraint", tparam)
67	}
68	iface, _ := constraint.Underlying().(*types.Interface)
69	if iface == nil {
70		return nil, fmt.Errorf("constraint is %T, not *types.Interface", constraint.Underlying())
71	}
72	return InterfaceTermSet(iface)
73}
74
75// InterfaceTermSet computes the normalized terms for a constraint interface,
76// returning an error if the term set cannot be computed or is empty. In the
77// latter case, the error will be ErrEmptyTypeSet.
78//
79// See the documentation of StructuralTerms for more information on
80// normalization.
81func InterfaceTermSet(iface *types.Interface) ([]*types.Term, error) {
82	return computeTermSet(iface)
83}
84
85// UnionTermSet computes the normalized terms for a union, returning an error
86// if the term set cannot be computed or is empty. In the latter case, the
87// error will be ErrEmptyTypeSet.
88//
89// See the documentation of StructuralTerms for more information on
90// normalization.
91func UnionTermSet(union *types.Union) ([]*types.Term, error) {
92	return computeTermSet(union)
93}
94
95func computeTermSet(typ types.Type) ([]*types.Term, error) {
96	tset, err := computeTermSetInternal(typ, make(map[types.Type]*termSet), 0)
97	if err != nil {
98		return nil, err
99	}
100	if tset.terms.isEmpty() {
101		return nil, ErrEmptyTypeSet
102	}
103	if tset.terms.isAll() {
104		return nil, nil
105	}
106	var terms []*types.Term
107	for _, term := range tset.terms {
108		terms = append(terms, types.NewTerm(term.tilde, term.typ))
109	}
110	return terms, nil
111}
112
// A termSet holds the normalized set of terms for a given type.
//
// The name termSet is intentionally distinct from 'type set': a type set is
// all types that implement a type (and includes method restrictions), whereas
// a term set just represents the structural restrictions on a type.
type termSet struct {
	// complete reports whether computeTermSetInternal has finished with this
	// entry. Meeting an entry in the seen map that is not yet complete means
	// the type is being computed further up the stack, i.e. a cycle.
	complete bool
	// terms is the normalized term list for the type. nil (the starting point
	// for a union) denotes the empty set; allTermlist (the starting point for
	// an interface's intersection) denotes the set of all types.
	terms    termlist
}
122
// indentf writes one debug trace line to standard error: depth dots, then
// format expanded with args, then a newline. It is used only when debug is set.
func indentf(depth int, format string, args ...interface{}) {
	prefix := strings.Repeat(".", depth)
	fmt.Fprintf(os.Stderr, prefix+format+"\n", args...)
}
126
127func computeTermSetInternal(t types.Type, seen map[types.Type]*termSet, depth int) (res *termSet, err error) {
128	if t == nil {
129		panic("nil type")
130	}
131
132	if debug {
133		indentf(depth, "%s", t.String())
134		defer func() {
135			if err != nil {
136				indentf(depth, "=> %s", err)
137			} else {
138				indentf(depth, "=> %s", res.terms.String())
139			}
140		}()
141	}
142
143	const maxTermCount = 100
144	if tset, ok := seen[t]; ok {
145		if !tset.complete {
146			return nil, fmt.Errorf("cycle detected in the declaration of %s", t)
147		}
148		return tset, nil
149	}
150
151	// Mark the current type as seen to avoid infinite recursion.
152	tset := new(termSet)
153	defer func() {
154		tset.complete = true
155	}()
156	seen[t] = tset
157
158	switch u := t.Underlying().(type) {
159	case *types.Interface:
160		// The term set of an interface is the intersection of the term sets of its
161		// embedded types.
162		tset.terms = allTermlist
163		for i := 0; i < u.NumEmbeddeds(); i++ {
164			embedded := u.EmbeddedType(i)
165			if _, ok := embedded.Underlying().(*types.TypeParam); ok {
166				return nil, fmt.Errorf("invalid embedded type %T", embedded)
167			}
168			tset2, err := computeTermSetInternal(embedded, seen, depth+1)
169			if err != nil {
170				return nil, err
171			}
172			tset.terms = tset.terms.intersect(tset2.terms)
173		}
174	case *types.Union:
175		// The term set of a union is the union of term sets of its terms.
176		tset.terms = nil
177		for i := 0; i < u.Len(); i++ {
178			t := u.Term(i)
179			var terms termlist
180			switch t.Type().Underlying().(type) {
181			case *types.Interface:
182				tset2, err := computeTermSetInternal(t.Type(), seen, depth+1)
183				if err != nil {
184					return nil, err
185				}
186				terms = tset2.terms
187			case *types.TypeParam, *types.Union:
188				// A stand-alone type parameter or union is not permitted as union
189				// term.
190				return nil, fmt.Errorf("invalid union term %T", t)
191			default:
192				if t.Type() == types.Typ[types.Invalid] {
193					continue
194				}
195				terms = termlist{{t.Tilde(), t.Type()}}
196			}
197			tset.terms = tset.terms.union(terms)
198			if len(tset.terms) > maxTermCount {
199				return nil, fmt.Errorf("exceeded max term count %d", maxTermCount)
200			}
201		}
202	case *types.TypeParam:
203		panic("unreachable")
204	default:
205		// For all other types, the term set is just a single non-tilde term
206		// holding the type itself.
207		if u != types.Typ[types.Invalid] {
208			tset.terms = termlist{{false, t}}
209		}
210	}
211	return tset, nil
212}
213
// under returns the underlying type of t. It stands in for the go/types
// internal function of the same name and is used by typeterm.go.
func under(t types.Type) types.Type {
	u := t.Underlying()
	return u
}
219