xref: /aosp_15_r20/external/starlark-go/lib/proto/proto.go (revision 4947cdc739c985f6d86941e22894f5cefe7c9e9a)
1*4947cdc7SCole Faust// Copyright 2020 The Bazel Authors. All rights reserved.
2*4947cdc7SCole Faust// Use of this source code is governed by a BSD-style
3*4947cdc7SCole Faust// license that can be found in the LICENSE file.
4*4947cdc7SCole Faust
5*4947cdc7SCole Faust// Package proto defines a module of utilities for constructing and
6*4947cdc7SCole Faust// accessing protocol messages within Starlark programs.
7*4947cdc7SCole Faust//
8*4947cdc7SCole Faust// THIS PACKAGE IS EXPERIMENTAL AND ITS INTERFACE MAY CHANGE.
9*4947cdc7SCole Faust//
10*4947cdc7SCole Faust// This package defines several types of Starlark value:
11*4947cdc7SCole Faust//
12*4947cdc7SCole Faust//      Message                 -- a protocol message
13*4947cdc7SCole Faust//      RepeatedField           -- a repeated field of a message, like a list
14*4947cdc7SCole Faust//
15*4947cdc7SCole Faust//      FileDescriptor          -- information about a .proto file
16*4947cdc7SCole Faust//      FieldDescriptor         -- information about a message field (or extension field)
17*4947cdc7SCole Faust//      MessageDescriptor       -- information about the type of a message
18*4947cdc7SCole Faust//      EnumDescriptor          -- information about an enumerated type
19*4947cdc7SCole Faust//      EnumValueDescriptor     -- a value of an enumerated type
20*4947cdc7SCole Faust//
21*4947cdc7SCole Faust// A Message value is a wrapper around a protocol message instance.
22*4947cdc7SCole Faust// Starlark programs may access and update Messages using dot notation:
23*4947cdc7SCole Faust//
24*4947cdc7SCole Faust//      x = msg.field
25*4947cdc7SCole Faust//      msg.field = x + 1
26*4947cdc7SCole Faust//      msg.field += 1
27*4947cdc7SCole Faust//
28*4947cdc7SCole Faust// Assignments to message fields perform dynamic checks on the type and
29*4947cdc7SCole Faust// range of the value to ensure that the message is at all times valid.
30*4947cdc7SCole Faust//
31*4947cdc7SCole Faust// The value of a repeated field of a message is represented by the
32*4947cdc7SCole Faust// list-like data type, RepeatedField.  Its elements may be accessed,
33*4947cdc7SCole Faust// iterated, and updated in the usual ways.  As with assignments to
34*4947cdc7SCole Faust// message fields, an assignment to an element of a RepeatedField
35*4947cdc7SCole Faust// performs a dynamic check to ensure that the RepeatedField holds
36*4947cdc7SCole Faust// only elements of the correct type.
37*4947cdc7SCole Faust//
38*4947cdc7SCole Faust//      type(msg.uint32s)       # "proto.repeated<uint32>"
39*4947cdc7SCole Faust//      msg.uint32s[0] = 1
40*4947cdc7SCole Faust//      msg.uint32s[0] = -1     # error: invalid uint32: -1
41*4947cdc7SCole Faust//
42*4947cdc7SCole Faust// Any iterable may be assigned to a repeated field of a message.  If
43*4947cdc7SCole Faust// the iterable is itself a value of type RepeatedField, the message
44*4947cdc7SCole Faust// field holds a reference to it.
45*4947cdc7SCole Faust//
46*4947cdc7SCole Faust//      msg2.uint32s = msg.uint32s      # both messages share one RepeatedField
47*4947cdc7SCole Faust//      msg.uint32s[0] = 123
48*4947cdc7SCole Faust//      print(msg2.uint32s[0])          # "123"
49*4947cdc7SCole Faust//
50*4947cdc7SCole Faust// The RepeatedFields' element types must match.
51*4947cdc7SCole Faust// It is not enough for the values to be merely valid:
52*4947cdc7SCole Faust//
53*4947cdc7SCole Faust//      msg.uint32s = [1, 2, 3]         # makes a copy
54*4947cdc7SCole Faust//      msg.uint64s = msg.uint32s       # error: repeated field has wrong type
55*4947cdc7SCole Faust//      msg.uint64s = list(msg.uint32s) # ok; makes a copy
56*4947cdc7SCole Faust//
57*4947cdc7SCole Faust// For all other iterables, a new RepeatedField is constructed from the
58*4947cdc7SCole Faust// elements of the iterable.
59*4947cdc7SCole Faust//
//      msg.uint32s = [1, 2, 3]
//      print(type(msg.uint32s))        # "proto.repeated<uint32>"
62*4947cdc7SCole Faust//
63*4947cdc7SCole Faust//
64*4947cdc7SCole Faust// To construct a Message from encoded binary or text data, call
65*4947cdc7SCole Faust// Unmarshal or UnmarshalText.  These two functions are exposed to
66*4947cdc7SCole Faust// Starlark programs as proto.unmarshal{,_text}.
67*4947cdc7SCole Faust//
68*4947cdc7SCole Faust// To construct a Message from an existing Go proto.Message instance,
69*4947cdc7SCole Faust// you must first encode the Go message to binary, then decode it using
70*4947cdc7SCole Faust// Unmarshal. This ensures that messages visible to Starlark are
71*4947cdc7SCole Faust// encapsulated and cannot be mutated once their Starlark wrapper values
72*4947cdc7SCole Faust// are frozen.
73*4947cdc7SCole Faust//
74*4947cdc7SCole Faust// TODO(adonovan): document descriptors, enums, message instantiation.
75*4947cdc7SCole Faust//
76*4947cdc7SCole Faust// See proto_test.go for an example of how to use the 'proto'
77*4947cdc7SCole Faust// module in an application that embeds Starlark.
78*4947cdc7SCole Faust//
79*4947cdc7SCole Faustpackage proto
80*4947cdc7SCole Faust
81*4947cdc7SCole Faust// TODO(adonovan): Go and Starlark API improvements:
82*4947cdc7SCole Faust// - Make Message and RepeatedField comparable.
83*4947cdc7SCole Faust//   (NOTE: proto.Equal works only with generated message types.)
84*4947cdc7SCole Faust// - Support maps, oneof, any. But not messageset if we can avoid it.
85*4947cdc7SCole Faust// - Support "well-known types".
86*4947cdc7SCole Faust// - Defend against cycles in object graph.
87*4947cdc7SCole Faust// - Test missing required fields in marshalling.
88*4947cdc7SCole Faust
89*4947cdc7SCole Faustimport (
90*4947cdc7SCole Faust	"bytes"
91*4947cdc7SCole Faust	"fmt"
92*4947cdc7SCole Faust	"sort"
93*4947cdc7SCole Faust	"strings"
94*4947cdc7SCole Faust	"unsafe"
95*4947cdc7SCole Faust	_ "unsafe" // for linkname hack
96*4947cdc7SCole Faust
97*4947cdc7SCole Faust	"google.golang.org/protobuf/encoding/prototext"
98*4947cdc7SCole Faust	"google.golang.org/protobuf/proto"
99*4947cdc7SCole Faust	"google.golang.org/protobuf/reflect/protoreflect"
100*4947cdc7SCole Faust	"google.golang.org/protobuf/reflect/protoregistry"
101*4947cdc7SCole Faust	"google.golang.org/protobuf/types/dynamicpb"
102*4947cdc7SCole Faust
103*4947cdc7SCole Faust	"go.starlark.net/starlark"
104*4947cdc7SCole Faust	"go.starlark.net/starlarkstruct"
105*4947cdc7SCole Faust	"go.starlark.net/syntax"
106*4947cdc7SCole Faust)
107*4947cdc7SCole Faust
108*4947cdc7SCole Faust// SetPool associates with the specified Starlark thread the
109*4947cdc7SCole Faust// descriptor pool used to find descriptors for .proto files and to
110*4947cdc7SCole Faust// instantiate messages from descriptors.  Clients must call SetPool
111*4947cdc7SCole Faust// for a Starlark thread to use this package.
112*4947cdc7SCole Faust//
113*4947cdc7SCole Faust// For example:
114*4947cdc7SCole Faust//	SetPool(thread, protoregistry.GlobalFiles)
115*4947cdc7SCole Faust//
116*4947cdc7SCole Faustfunc SetPool(thread *starlark.Thread, pool DescriptorPool) {
117*4947cdc7SCole Faust	thread.SetLocal(contextKey, pool)
118*4947cdc7SCole Faust}
119*4947cdc7SCole Faust
120*4947cdc7SCole Faust// Pool returns the descriptor pool previously associated with this thread.
121*4947cdc7SCole Faustfunc Pool(thread *starlark.Thread) DescriptorPool {
122*4947cdc7SCole Faust	pool, _ := thread.Local(contextKey).(DescriptorPool)
123*4947cdc7SCole Faust	return pool
124*4947cdc7SCole Faust}
125*4947cdc7SCole Faust
126*4947cdc7SCole Faustconst contextKey = "proto.DescriptorPool"
127*4947cdc7SCole Faust
128*4947cdc7SCole Faust// A DescriptorPool loads FileDescriptors by path name or package name,
129*4947cdc7SCole Faust// possibly on demand.
130*4947cdc7SCole Faust//
131*4947cdc7SCole Faust// It is a superinterface of protodesc.Resolver, so any Resolver
132*4947cdc7SCole Faust// implementation is a valid pool. For example.
133*4947cdc7SCole Faust// protoregistry.GlobalFiles, which loads FileDescriptors from the
134*4947cdc7SCole Faust// compressed binary information in all the *.pb.go files linked into
135*4947cdc7SCole Faust// the process; and protodesc.NewFiles, which holds a set of
136*4947cdc7SCole Faust// FileDescriptorSet messages. See star2proto for example usage.
137*4947cdc7SCole Fausttype DescriptorPool interface {
138*4947cdc7SCole Faust	FindFileByPath(string) (protoreflect.FileDescriptor, error)
139*4947cdc7SCole Faust}
140*4947cdc7SCole Faust
141*4947cdc7SCole Faustvar Module = &starlarkstruct.Module{
142*4947cdc7SCole Faust	Name: "proto",
143*4947cdc7SCole Faust	Members: starlark.StringDict{
144*4947cdc7SCole Faust		"file":           starlark.NewBuiltin("proto.file", file),
145*4947cdc7SCole Faust		"has":            starlark.NewBuiltin("proto.has", has),
146*4947cdc7SCole Faust		"marshal":        starlark.NewBuiltin("proto.marshal", marshal),
147*4947cdc7SCole Faust		"marshal_text":   starlark.NewBuiltin("proto.marshal_text", marshal),
148*4947cdc7SCole Faust		"set_field":      starlark.NewBuiltin("proto.set_field", setFieldStarlark),
149*4947cdc7SCole Faust		"get_field":      starlark.NewBuiltin("proto.get_field", getFieldStarlark),
150*4947cdc7SCole Faust		"unmarshal":      starlark.NewBuiltin("proto.unmarshal", unmarshal),
151*4947cdc7SCole Faust		"unmarshal_text": starlark.NewBuiltin("proto.unmarshal_text", unmarshal_text),
152*4947cdc7SCole Faust
153*4947cdc7SCole Faust		// TODO(adonovan):
154*4947cdc7SCole Faust		// - merge(msg, msg) -> msg
155*4947cdc7SCole Faust		// - equals(msg, msg) -> bool
156*4947cdc7SCole Faust		// - diff(msg, msg) -> string
157*4947cdc7SCole Faust		// - clone(msg) -> msg
158*4947cdc7SCole Faust	},
159*4947cdc7SCole Faust}
160*4947cdc7SCole Faust
161*4947cdc7SCole Faust// file(filename) loads the FileDescriptor of the given name, or the
162*4947cdc7SCole Faust// first if the pool contains more than one.
163*4947cdc7SCole Faust//
164*4947cdc7SCole Faust// It's unfortunate that renaming a .proto file in effect breaks the
165*4947cdc7SCole Faust// interface it presents to Starlark. Ideally one would import
166*4947cdc7SCole Faust// descriptors by package name, but there may be many FileDescriptors
167*4947cdc7SCole Faust// for the same package name, and there is no "package descriptor".
168*4947cdc7SCole Faust// (Technically a pool may also have many FileDescriptors with the same
169*4947cdc7SCole Faust// file name, but this can't happen with a single consistent snapshot.)
170*4947cdc7SCole Faustfunc file(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
171*4947cdc7SCole Faust	var filename string
172*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 1, &filename); err != nil {
173*4947cdc7SCole Faust		return nil, err
174*4947cdc7SCole Faust	}
175*4947cdc7SCole Faust
176*4947cdc7SCole Faust	pool := Pool(thread)
177*4947cdc7SCole Faust	if pool == nil {
178*4947cdc7SCole Faust		return nil, fmt.Errorf("internal error: SetPool was not called")
179*4947cdc7SCole Faust	}
180*4947cdc7SCole Faust
181*4947cdc7SCole Faust	desc, err := pool.FindFileByPath(filename)
182*4947cdc7SCole Faust	if err != nil {
183*4947cdc7SCole Faust		return nil, err
184*4947cdc7SCole Faust	}
185*4947cdc7SCole Faust
186*4947cdc7SCole Faust	return FileDescriptor{Desc: desc}, nil
187*4947cdc7SCole Faust}
188*4947cdc7SCole Faust
189*4947cdc7SCole Faust// has(msg, field) reports whether the specified field of the message is present.
190*4947cdc7SCole Faust// A field may be specified by name (string) or FieldDescriptor.
191*4947cdc7SCole Faust// has reports an error if the message type has no such field.
192*4947cdc7SCole Faustfunc has(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
193*4947cdc7SCole Faust	var x, field starlark.Value
194*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 2, &x, &field); err != nil {
195*4947cdc7SCole Faust		return nil, err
196*4947cdc7SCole Faust	}
197*4947cdc7SCole Faust	msg, ok := x.(*Message)
198*4947cdc7SCole Faust	if !ok {
199*4947cdc7SCole Faust		return nil, fmt.Errorf("%s: got %s, want proto.Message", fn.Name(), x.Type())
200*4947cdc7SCole Faust	}
201*4947cdc7SCole Faust
202*4947cdc7SCole Faust	var fdesc protoreflect.FieldDescriptor
203*4947cdc7SCole Faust	switch field := field.(type) {
204*4947cdc7SCole Faust	case starlark.String:
205*4947cdc7SCole Faust		var err error
206*4947cdc7SCole Faust		fdesc, err = fieldDesc(msg.desc(), string(field))
207*4947cdc7SCole Faust		if err != nil {
208*4947cdc7SCole Faust			return nil, err
209*4947cdc7SCole Faust		}
210*4947cdc7SCole Faust
211*4947cdc7SCole Faust	case FieldDescriptor:
212*4947cdc7SCole Faust		if field.Desc.ContainingMessage() != msg.desc() {
213*4947cdc7SCole Faust			return nil, fmt.Errorf("%s: %v does not have field %v", fn.Name(), msg.desc().FullName(), field)
214*4947cdc7SCole Faust		}
215*4947cdc7SCole Faust		fdesc = field.Desc
216*4947cdc7SCole Faust
217*4947cdc7SCole Faust	default:
218*4947cdc7SCole Faust		return nil, fmt.Errorf("%s: for field argument, got %s, want string or proto.FieldDescriptor", fn.Name(), field.Type())
219*4947cdc7SCole Faust	}
220*4947cdc7SCole Faust
221*4947cdc7SCole Faust	return starlark.Bool(msg.msg.Has(fdesc)), nil
222*4947cdc7SCole Faust}
223*4947cdc7SCole Faust
224*4947cdc7SCole Faust// marshal{,_text}(msg) encodes a Message value to binary or text form.
225*4947cdc7SCole Faustfunc marshal(_ *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
226*4947cdc7SCole Faust	var m *Message
227*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 1, &m); err != nil {
228*4947cdc7SCole Faust		return nil, err
229*4947cdc7SCole Faust	}
230*4947cdc7SCole Faust	if fn.Name() == "proto.marshal" {
231*4947cdc7SCole Faust		data, err := proto.Marshal(m.Message())
232*4947cdc7SCole Faust		if err != nil {
233*4947cdc7SCole Faust			return nil, fmt.Errorf("%s: %v", fn.Name(), err)
234*4947cdc7SCole Faust		}
235*4947cdc7SCole Faust		return starlark.Bytes(data), nil
236*4947cdc7SCole Faust	} else {
237*4947cdc7SCole Faust		text, err := prototext.MarshalOptions{Indent: "  "}.Marshal(m.Message())
238*4947cdc7SCole Faust		if err != nil {
239*4947cdc7SCole Faust			return nil, fmt.Errorf("%s: %v", fn.Name(), err)
240*4947cdc7SCole Faust		}
241*4947cdc7SCole Faust		return starlark.String(text), nil
242*4947cdc7SCole Faust	}
243*4947cdc7SCole Faust}
244*4947cdc7SCole Faust
245*4947cdc7SCole Faust// unmarshal(msg) decodes a binary protocol message to a Message.
246*4947cdc7SCole Faustfunc unmarshal(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
247*4947cdc7SCole Faust	var desc MessageDescriptor
248*4947cdc7SCole Faust	var data starlark.Bytes
249*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 2, &desc, &data); err != nil {
250*4947cdc7SCole Faust		return nil, err
251*4947cdc7SCole Faust	}
252*4947cdc7SCole Faust	return unmarshalData(desc.Desc, []byte(data), true)
253*4947cdc7SCole Faust}
254*4947cdc7SCole Faust
255*4947cdc7SCole Faust// unmarshal_text(msg) decodes a text protocol message to a Message.
256*4947cdc7SCole Faustfunc unmarshal_text(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
257*4947cdc7SCole Faust	var desc MessageDescriptor
258*4947cdc7SCole Faust	var data string
259*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 2, &desc, &data); err != nil {
260*4947cdc7SCole Faust		return nil, err
261*4947cdc7SCole Faust	}
262*4947cdc7SCole Faust	return unmarshalData(desc.Desc, []byte(data), false)
263*4947cdc7SCole Faust}
264*4947cdc7SCole Faust
265*4947cdc7SCole Faust// set_field(msg, field, value) updates the value of a field.
266*4947cdc7SCole Faust// It is typically used for extensions, which cannot be updated using msg.field = v notation.
267*4947cdc7SCole Faustfunc setFieldStarlark(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
268*4947cdc7SCole Faust	// TODO(adonovan): allow field to be specified by name (for non-extension fields), like has?
269*4947cdc7SCole Faust	var m *Message
270*4947cdc7SCole Faust	var field FieldDescriptor
271*4947cdc7SCole Faust	var v starlark.Value
272*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 3, &m, &field, &v); err != nil {
273*4947cdc7SCole Faust		return nil, err
274*4947cdc7SCole Faust	}
275*4947cdc7SCole Faust
276*4947cdc7SCole Faust	if *m.frozen {
277*4947cdc7SCole Faust		return nil, fmt.Errorf("%s: cannot set %v field of frozen %v message", fn.Name(), field, m.desc().FullName())
278*4947cdc7SCole Faust	}
279*4947cdc7SCole Faust
280*4947cdc7SCole Faust	if field.Desc.ContainingMessage() != m.desc() {
281*4947cdc7SCole Faust		return nil, fmt.Errorf("%s: %v does not have field %v", fn.Name(), m.desc().FullName(), field)
282*4947cdc7SCole Faust	}
283*4947cdc7SCole Faust
284*4947cdc7SCole Faust	return starlark.None, setField(m.msg, field.Desc, v)
285*4947cdc7SCole Faust}
286*4947cdc7SCole Faust
287*4947cdc7SCole Faust// get_field(msg, field) retrieves the value of a field.
288*4947cdc7SCole Faust// It is typically used for extension fields, which cannot be accessed using msg.field notation.
289*4947cdc7SCole Faustfunc getFieldStarlark(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
290*4947cdc7SCole Faust	// TODO(adonovan): allow field to be specified by name (for non-extension fields), like has?
291*4947cdc7SCole Faust	var msg *Message
292*4947cdc7SCole Faust	var field FieldDescriptor
293*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 2, &msg, &field); err != nil {
294*4947cdc7SCole Faust		return nil, err
295*4947cdc7SCole Faust	}
296*4947cdc7SCole Faust
297*4947cdc7SCole Faust	if field.Desc.ContainingMessage() != msg.desc() {
298*4947cdc7SCole Faust		return nil, fmt.Errorf("%s: %v does not have field %v", fn.Name(), msg.desc().FullName(), field)
299*4947cdc7SCole Faust	}
300*4947cdc7SCole Faust
301*4947cdc7SCole Faust	return msg.getField(field.Desc), nil
302*4947cdc7SCole Faust}
303*4947cdc7SCole Faust
304*4947cdc7SCole Faust// The Call method implements the starlark.Callable interface.
305*4947cdc7SCole Faust// When a message descriptor is called, it returns a new instance of the
306*4947cdc7SCole Faust// protocol message it describes.
307*4947cdc7SCole Faust//
308*4947cdc7SCole Faust//      Message(msg)            -- return a shallow copy of an existing message
309*4947cdc7SCole Faust//      Message(k=v, ...)       -- return a new message with the specified fields
310*4947cdc7SCole Faust//      Message(dict(...))      -- return a new message with the specified fields
311*4947cdc7SCole Faust//
312*4947cdc7SCole Faustfunc (d MessageDescriptor) CallInternal(thread *starlark.Thread, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
313*4947cdc7SCole Faust	dest := &Message{
314*4947cdc7SCole Faust		msg:    newMessage(d.Desc),
315*4947cdc7SCole Faust		frozen: new(bool),
316*4947cdc7SCole Faust	}
317*4947cdc7SCole Faust
318*4947cdc7SCole Faust	// Single positional argument?
319*4947cdc7SCole Faust	if len(args) > 0 {
320*4947cdc7SCole Faust		if len(kwargs) > 0 {
321*4947cdc7SCole Faust			return nil, fmt.Errorf("%s: got both positional and named arguments", d.Desc.Name())
322*4947cdc7SCole Faust		}
323*4947cdc7SCole Faust		if len(args) > 1 {
324*4947cdc7SCole Faust			return nil, fmt.Errorf("%s: got %d positional arguments, want at most 1", d.Desc.Name(), len(args))
325*4947cdc7SCole Faust		}
326*4947cdc7SCole Faust
327*4947cdc7SCole Faust		// Keep consistent with MessageKind case of toProto.
328*4947cdc7SCole Faust		// (support the same argument types).
329*4947cdc7SCole Faust		switch src := args[0].(type) {
330*4947cdc7SCole Faust		case *Message:
331*4947cdc7SCole Faust			if dest.desc() != src.desc() {
332*4947cdc7SCole Faust				return nil, fmt.Errorf("%s: got message of type %s, want type %s", d.Desc.Name(), src.desc().FullName(), dest.desc().FullName())
333*4947cdc7SCole Faust			}
334*4947cdc7SCole Faust
335*4947cdc7SCole Faust			// Make shallow copy of message.
336*4947cdc7SCole Faust			// TODO(adonovan): How does frozen work if we have shallow copy?
337*4947cdc7SCole Faust			src.msg.Range(func(fdesc protoreflect.FieldDescriptor, v protoreflect.Value) bool {
338*4947cdc7SCole Faust				dest.msg.Set(fdesc, v)
339*4947cdc7SCole Faust				return true
340*4947cdc7SCole Faust			})
341*4947cdc7SCole Faust			return dest, nil
342*4947cdc7SCole Faust
343*4947cdc7SCole Faust		case *starlark.Dict:
344*4947cdc7SCole Faust			kwargs = src.Items()
345*4947cdc7SCole Faust			// fall through
346*4947cdc7SCole Faust
347*4947cdc7SCole Faust		default:
348*4947cdc7SCole Faust			return nil, fmt.Errorf("%s: got %s, want dict or message", d.Desc.Name(), src.Type())
349*4947cdc7SCole Faust		}
350*4947cdc7SCole Faust	}
351*4947cdc7SCole Faust
352*4947cdc7SCole Faust	// Convert named arguments to field values.
353*4947cdc7SCole Faust	err := setFields(dest.msg, kwargs)
354*4947cdc7SCole Faust	return dest, err
355*4947cdc7SCole Faust}
356*4947cdc7SCole Faust
357*4947cdc7SCole Faust// setFields updates msg as if by msg.name=value for each (name, value) in items.
358*4947cdc7SCole Faustfunc setFields(msg protoreflect.Message, items []starlark.Tuple) error {
359*4947cdc7SCole Faust	for _, item := range items {
360*4947cdc7SCole Faust		name, ok := starlark.AsString(item[0])
361*4947cdc7SCole Faust		if !ok {
362*4947cdc7SCole Faust			return fmt.Errorf("got %s, want string", item[0].Type())
363*4947cdc7SCole Faust		}
364*4947cdc7SCole Faust		fdesc, err := fieldDesc(msg.Descriptor(), name)
365*4947cdc7SCole Faust		if err != nil {
366*4947cdc7SCole Faust			return err
367*4947cdc7SCole Faust		}
368*4947cdc7SCole Faust		if err := setField(msg, fdesc, item[1]); err != nil {
369*4947cdc7SCole Faust			return err
370*4947cdc7SCole Faust		}
371*4947cdc7SCole Faust	}
372*4947cdc7SCole Faust	return nil
373*4947cdc7SCole Faust}
374*4947cdc7SCole Faust
375*4947cdc7SCole Faust// setField validates a Starlark field value, converts it to canonical form,
376*4947cdc7SCole Faust// and assigns to the field of msg.  If value is None, the field is unset.
377*4947cdc7SCole Faustfunc setField(msg protoreflect.Message, fdesc protoreflect.FieldDescriptor, value starlark.Value) error {
378*4947cdc7SCole Faust	// None unsets a field.
379*4947cdc7SCole Faust	if value == starlark.None {
380*4947cdc7SCole Faust		msg.Clear(fdesc)
381*4947cdc7SCole Faust		return nil
382*4947cdc7SCole Faust	}
383*4947cdc7SCole Faust
384*4947cdc7SCole Faust	// Assigning to a repeated field must make a copy,
385*4947cdc7SCole Faust	// because the fields.Set doesn't specify whether
386*4947cdc7SCole Faust	// it aliases the list or not, so we cannot assume.
387*4947cdc7SCole Faust	//
388*4947cdc7SCole Faust	// This is potentially surprising as
389*4947cdc7SCole Faust	//  x = []; msg.x = x; y = msg.x
390*4947cdc7SCole Faust	// causes x and y not to alias.
391*4947cdc7SCole Faust	if fdesc.IsList() {
392*4947cdc7SCole Faust		iter := starlark.Iterate(value)
393*4947cdc7SCole Faust		if iter == nil {
394*4947cdc7SCole Faust			return fmt.Errorf("got %s for .%s field, want iterable", value.Type(), fdesc.Name())
395*4947cdc7SCole Faust		}
396*4947cdc7SCole Faust		defer iter.Done()
397*4947cdc7SCole Faust
398*4947cdc7SCole Faust		// TODO(adonovan): handle maps
399*4947cdc7SCole Faust		list := msg.Mutable(fdesc).List()
400*4947cdc7SCole Faust		var x starlark.Value
401*4947cdc7SCole Faust		for i := 0; iter.Next(&x); i++ {
402*4947cdc7SCole Faust			v, err := toProto(fdesc, x)
403*4947cdc7SCole Faust			if err != nil {
404*4947cdc7SCole Faust				return fmt.Errorf("index %d: %v", i, err)
405*4947cdc7SCole Faust			}
406*4947cdc7SCole Faust			list.Append(v)
407*4947cdc7SCole Faust		}
408*4947cdc7SCole Faust		return nil
409*4947cdc7SCole Faust	}
410*4947cdc7SCole Faust
411*4947cdc7SCole Faust	v, err := toProto(fdesc, value)
412*4947cdc7SCole Faust	if err != nil {
413*4947cdc7SCole Faust		return fmt.Errorf("in field %s: %v", fdesc.Name(), err)
414*4947cdc7SCole Faust	}
415*4947cdc7SCole Faust
416*4947cdc7SCole Faust	if fdesc.IsExtension() {
417*4947cdc7SCole Faust		// The protoreflect.Message.NewField method must be able
418*4947cdc7SCole Faust		// to return a new instance of the field type. Without
419*4947cdc7SCole Faust		// having the Go type information available for extensions,
420*4947cdc7SCole Faust		// the implementation of NewField won't know what to do.
421*4947cdc7SCole Faust		//
422*4947cdc7SCole Faust		// Thus we must augment the FieldDescriptor to one that
423*4947cdc7SCole Faust		// additional holds Go representation type information
424*4947cdc7SCole Faust		// (based in this case on dynamicpb).
425*4947cdc7SCole Faust		fdesc = dynamicpb.NewExtensionType(fdesc).TypeDescriptor()
426*4947cdc7SCole Faust		_ = fdesc.(protoreflect.ExtensionTypeDescriptor)
427*4947cdc7SCole Faust	}
428*4947cdc7SCole Faust
429*4947cdc7SCole Faust	msg.Set(fdesc, v)
430*4947cdc7SCole Faust	return nil
431*4947cdc7SCole Faust}
432*4947cdc7SCole Faust
433*4947cdc7SCole Faust// toProto converts a Starlark value for a message field into protoreflect form.
434*4947cdc7SCole Faustfunc toProto(fdesc protoreflect.FieldDescriptor, v starlark.Value) (protoreflect.Value, error) {
435*4947cdc7SCole Faust	switch fdesc.Kind() {
436*4947cdc7SCole Faust	case protoreflect.BoolKind:
437*4947cdc7SCole Faust		// To avoid mistakes, we require v be exactly a bool.
438*4947cdc7SCole Faust		if v, ok := v.(starlark.Bool); ok {
439*4947cdc7SCole Faust			return protoreflect.ValueOfBool(bool(v)), nil
440*4947cdc7SCole Faust		}
441*4947cdc7SCole Faust
442*4947cdc7SCole Faust	case protoreflect.Fixed32Kind,
443*4947cdc7SCole Faust		protoreflect.Uint32Kind:
444*4947cdc7SCole Faust		// uint32
445*4947cdc7SCole Faust		if i, ok := v.(starlark.Int); ok {
446*4947cdc7SCole Faust			if u, ok := i.Uint64(); ok && uint64(uint32(u)) == u {
447*4947cdc7SCole Faust				return protoreflect.ValueOfUint32(uint32(u)), nil
448*4947cdc7SCole Faust			}
449*4947cdc7SCole Faust			return noValue, fmt.Errorf("invalid %s: %v", typeString(fdesc), i)
450*4947cdc7SCole Faust		}
451*4947cdc7SCole Faust
452*4947cdc7SCole Faust	case protoreflect.Int32Kind,
453*4947cdc7SCole Faust		protoreflect.Sfixed32Kind,
454*4947cdc7SCole Faust		protoreflect.Sint32Kind:
455*4947cdc7SCole Faust		// int32
456*4947cdc7SCole Faust		if i, ok := v.(starlark.Int); ok {
457*4947cdc7SCole Faust			if i, ok := i.Int64(); ok && int64(int32(i)) == i {
458*4947cdc7SCole Faust				return protoreflect.ValueOfInt32(int32(i)), nil
459*4947cdc7SCole Faust			}
460*4947cdc7SCole Faust			return noValue, fmt.Errorf("invalid %s: %v", typeString(fdesc), i)
461*4947cdc7SCole Faust		}
462*4947cdc7SCole Faust
463*4947cdc7SCole Faust	case protoreflect.Uint64Kind,
464*4947cdc7SCole Faust		protoreflect.Fixed64Kind:
465*4947cdc7SCole Faust		// uint64
466*4947cdc7SCole Faust		if i, ok := v.(starlark.Int); ok {
467*4947cdc7SCole Faust			if u, ok := i.Uint64(); ok {
468*4947cdc7SCole Faust				return protoreflect.ValueOfUint64(u), nil
469*4947cdc7SCole Faust			}
470*4947cdc7SCole Faust			return noValue, fmt.Errorf("invalid %s: %v", typeString(fdesc), i)
471*4947cdc7SCole Faust		}
472*4947cdc7SCole Faust
473*4947cdc7SCole Faust	case protoreflect.Int64Kind,
474*4947cdc7SCole Faust		protoreflect.Sfixed64Kind,
475*4947cdc7SCole Faust		protoreflect.Sint64Kind:
476*4947cdc7SCole Faust		// int64
477*4947cdc7SCole Faust		if i, ok := v.(starlark.Int); ok {
478*4947cdc7SCole Faust			if i, ok := i.Int64(); ok {
479*4947cdc7SCole Faust				return protoreflect.ValueOfInt64(i), nil
480*4947cdc7SCole Faust			}
481*4947cdc7SCole Faust			return noValue, fmt.Errorf("invalid %s: %v", typeString(fdesc), i)
482*4947cdc7SCole Faust		}
483*4947cdc7SCole Faust
484*4947cdc7SCole Faust	case protoreflect.StringKind:
485*4947cdc7SCole Faust		if s, ok := starlark.AsString(v); ok {
486*4947cdc7SCole Faust			return protoreflect.ValueOfString(s), nil
487*4947cdc7SCole Faust		} else if b, ok := v.(starlark.Bytes); ok {
488*4947cdc7SCole Faust			// TODO(adonovan): allow bytes for string? Not friendly to a Java port.
489*4947cdc7SCole Faust			return protoreflect.ValueOfBytes([]byte(b)), nil
490*4947cdc7SCole Faust		}
491*4947cdc7SCole Faust
492*4947cdc7SCole Faust	case protoreflect.BytesKind:
493*4947cdc7SCole Faust		if s, ok := starlark.AsString(v); ok {
494*4947cdc7SCole Faust			// TODO(adonovan): don't allow string for bytes: it's hostile to a Java port.
495*4947cdc7SCole Faust			// Instead provide b"..." literals in the core
496*4947cdc7SCole Faust			// and a bytes(str) conversion.
497*4947cdc7SCole Faust			return protoreflect.ValueOfBytes([]byte(s)), nil
498*4947cdc7SCole Faust		} else if b, ok := v.(starlark.Bytes); ok {
499*4947cdc7SCole Faust			return protoreflect.ValueOfBytes([]byte(b)), nil
500*4947cdc7SCole Faust		}
501*4947cdc7SCole Faust
502*4947cdc7SCole Faust	case protoreflect.DoubleKind:
503*4947cdc7SCole Faust		switch v := v.(type) {
504*4947cdc7SCole Faust		case starlark.Float:
505*4947cdc7SCole Faust			return protoreflect.ValueOfFloat64(float64(v)), nil
506*4947cdc7SCole Faust		case starlark.Int:
507*4947cdc7SCole Faust			return protoreflect.ValueOfFloat64(float64(v.Float())), nil
508*4947cdc7SCole Faust		}
509*4947cdc7SCole Faust
510*4947cdc7SCole Faust	case protoreflect.FloatKind:
511*4947cdc7SCole Faust		switch v := v.(type) {
512*4947cdc7SCole Faust		case starlark.Float:
513*4947cdc7SCole Faust			return protoreflect.ValueOfFloat32(float32(v)), nil
514*4947cdc7SCole Faust		case starlark.Int:
515*4947cdc7SCole Faust			return protoreflect.ValueOfFloat32(float32(v.Float())), nil
516*4947cdc7SCole Faust		}
517*4947cdc7SCole Faust
518*4947cdc7SCole Faust	case protoreflect.GroupKind,
519*4947cdc7SCole Faust		protoreflect.MessageKind:
520*4947cdc7SCole Faust		// Keep consistent with MessageDescriptor.CallInternal!
521*4947cdc7SCole Faust		desc := fdesc.Message()
522*4947cdc7SCole Faust		switch v := v.(type) {
523*4947cdc7SCole Faust		case *Message:
524*4947cdc7SCole Faust			if desc != v.desc() {
525*4947cdc7SCole Faust				return noValue, fmt.Errorf("got %s, want %s", v.desc().FullName(), desc.FullName())
526*4947cdc7SCole Faust			}
527*4947cdc7SCole Faust			return protoreflect.ValueOfMessage(v.msg), nil // alias it directly
528*4947cdc7SCole Faust
529*4947cdc7SCole Faust		case *starlark.Dict:
530*4947cdc7SCole Faust			dest := newMessage(desc)
531*4947cdc7SCole Faust			err := setFields(dest, v.Items())
532*4947cdc7SCole Faust			return protoreflect.ValueOfMessage(dest), err
533*4947cdc7SCole Faust		}
534*4947cdc7SCole Faust
535*4947cdc7SCole Faust	case protoreflect.EnumKind:
536*4947cdc7SCole Faust		enumval, err := enumValueOf(fdesc.Enum(), v)
537*4947cdc7SCole Faust		if err != nil {
538*4947cdc7SCole Faust			return noValue, err
539*4947cdc7SCole Faust		}
540*4947cdc7SCole Faust		return protoreflect.ValueOfEnum(enumval.Number()), nil
541*4947cdc7SCole Faust	}
542*4947cdc7SCole Faust
543*4947cdc7SCole Faust	return noValue, fmt.Errorf("got %s, want %s", v.Type(), typeString(fdesc))
544*4947cdc7SCole Faust}
545*4947cdc7SCole Faust
546*4947cdc7SCole Faustvar noValue protoreflect.Value
547*4947cdc7SCole Faust
548*4947cdc7SCole Faust// toStarlark returns a Starlark value for the value x of a message field.
549*4947cdc7SCole Faust// If the result is a repeated field or message,
550*4947cdc7SCole Faust// the result aliases the original and has the specified "frozenness" flag.
551*4947cdc7SCole Faust//
552*4947cdc7SCole Faust// fdesc is only used for the type, not other properties of the field.
553*4947cdc7SCole Faustfunc toStarlark(typ protoreflect.FieldDescriptor, x protoreflect.Value, frozen *bool) starlark.Value {
554*4947cdc7SCole Faust	if list, ok := x.Interface().(protoreflect.List); ok {
555*4947cdc7SCole Faust		return &RepeatedField{
556*4947cdc7SCole Faust			typ:    typ,
557*4947cdc7SCole Faust			list:   list,
558*4947cdc7SCole Faust			frozen: frozen,
559*4947cdc7SCole Faust		}
560*4947cdc7SCole Faust	}
561*4947cdc7SCole Faust	return toStarlark1(typ, x, frozen)
562*4947cdc7SCole Faust}
563*4947cdc7SCole Faust
564*4947cdc7SCole Faust// toStarlark1, for scalar (non-repeated) values only.
565*4947cdc7SCole Faustfunc toStarlark1(typ protoreflect.FieldDescriptor, x protoreflect.Value, frozen *bool) starlark.Value {
566*4947cdc7SCole Faust
567*4947cdc7SCole Faust	switch typ.Kind() {
568*4947cdc7SCole Faust	case protoreflect.BoolKind:
569*4947cdc7SCole Faust		return starlark.Bool(x.Bool())
570*4947cdc7SCole Faust
571*4947cdc7SCole Faust	case protoreflect.Fixed32Kind,
572*4947cdc7SCole Faust		protoreflect.Uint32Kind,
573*4947cdc7SCole Faust		protoreflect.Uint64Kind,
574*4947cdc7SCole Faust		protoreflect.Fixed64Kind:
575*4947cdc7SCole Faust		return starlark.MakeUint64(x.Uint())
576*4947cdc7SCole Faust
577*4947cdc7SCole Faust	case protoreflect.Int32Kind,
578*4947cdc7SCole Faust		protoreflect.Sfixed32Kind,
579*4947cdc7SCole Faust		protoreflect.Sint32Kind,
580*4947cdc7SCole Faust		protoreflect.Int64Kind,
581*4947cdc7SCole Faust		protoreflect.Sfixed64Kind,
582*4947cdc7SCole Faust		protoreflect.Sint64Kind:
583*4947cdc7SCole Faust		return starlark.MakeInt64(x.Int())
584*4947cdc7SCole Faust
585*4947cdc7SCole Faust	case protoreflect.StringKind:
586*4947cdc7SCole Faust		return starlark.String(x.String())
587*4947cdc7SCole Faust
588*4947cdc7SCole Faust	case protoreflect.BytesKind:
589*4947cdc7SCole Faust		return starlark.Bytes(x.Bytes())
590*4947cdc7SCole Faust
591*4947cdc7SCole Faust	case protoreflect.DoubleKind, protoreflect.FloatKind:
592*4947cdc7SCole Faust		return starlark.Float(x.Float())
593*4947cdc7SCole Faust
594*4947cdc7SCole Faust	case protoreflect.GroupKind, protoreflect.MessageKind:
595*4947cdc7SCole Faust		return &Message{
596*4947cdc7SCole Faust			msg:    x.Message(),
597*4947cdc7SCole Faust			frozen: frozen,
598*4947cdc7SCole Faust		}
599*4947cdc7SCole Faust
600*4947cdc7SCole Faust	case protoreflect.EnumKind:
601*4947cdc7SCole Faust		// Invariant: only EnumValueDescriptor may appear here.
602*4947cdc7SCole Faust		enumval := typ.Enum().Values().ByNumber(x.Enum())
603*4947cdc7SCole Faust		return EnumValueDescriptor{Desc: enumval}
604*4947cdc7SCole Faust	}
605*4947cdc7SCole Faust
606*4947cdc7SCole Faust	panic(fmt.Sprintf("got %T, want %s", x, typeString(typ)))
607*4947cdc7SCole Faust}
608*4947cdc7SCole Faust
// A Message is a Starlark value that wraps a protocol message.
//
// Two Messages are equivalent if and only if they are identical.
//
// When a Message value becomes frozen, a Starlark program may
// not modify the underlying protocol message, nor any Message
// or RepeatedField wrapper values derived from it.
//
// The zero Message is not usable: frozen must point to a flag, because
// Freeze and SetField dereference it.
type Message struct {
	msg    protoreflect.Message // any concrete type is allowed
	frozen *bool                // shared by a group of related Message/RepeatedField wrappers
}
620*4947cdc7SCole Faust
621*4947cdc7SCole Faust// Message returns the wrapped message.
622*4947cdc7SCole Faustfunc (m *Message) Message() protoreflect.ProtoMessage { return m.msg.Interface() }
623*4947cdc7SCole Faust
624*4947cdc7SCole Faustfunc (m *Message) desc() protoreflect.MessageDescriptor { return m.msg.Descriptor() }
625*4947cdc7SCole Faust
626*4947cdc7SCole Faustvar _ starlark.HasSetField = (*Message)(nil)
627*4947cdc7SCole Faust
628*4947cdc7SCole Faust// Unmarshal parses the data as a binary protocol message of the specified type,
629*4947cdc7SCole Faust// and returns it as a new Starlark message value.
630*4947cdc7SCole Faustfunc Unmarshal(desc protoreflect.MessageDescriptor, data []byte) (*Message, error) {
631*4947cdc7SCole Faust	return unmarshalData(desc, data, true)
632*4947cdc7SCole Faust}
633*4947cdc7SCole Faust
634*4947cdc7SCole Faust// UnmarshalText parses the data as a text protocol message of the specified type,
635*4947cdc7SCole Faust// and returns it as a new Starlark message value.
636*4947cdc7SCole Faustfunc UnmarshalText(desc protoreflect.MessageDescriptor, data []byte) (*Message, error) {
637*4947cdc7SCole Faust	return unmarshalData(desc, data, false)
638*4947cdc7SCole Faust}
639*4947cdc7SCole Faust
640*4947cdc7SCole Faust// unmarshalData constructs a Starlark proto.Message by decoding binary or text data.
641*4947cdc7SCole Faustfunc unmarshalData(desc protoreflect.MessageDescriptor, data []byte, binary bool) (*Message, error) {
642*4947cdc7SCole Faust	m := &Message{
643*4947cdc7SCole Faust		msg:    newMessage(desc),
644*4947cdc7SCole Faust		frozen: new(bool),
645*4947cdc7SCole Faust	}
646*4947cdc7SCole Faust	var err error
647*4947cdc7SCole Faust	if binary {
648*4947cdc7SCole Faust		err = proto.Unmarshal(data, m.Message())
649*4947cdc7SCole Faust	} else {
650*4947cdc7SCole Faust		err = prototext.Unmarshal(data, m.Message())
651*4947cdc7SCole Faust	}
652*4947cdc7SCole Faust	if err != nil {
653*4947cdc7SCole Faust		return nil, fmt.Errorf("unmarshalling %s failed: %v", desc.FullName(), err)
654*4947cdc7SCole Faust	}
655*4947cdc7SCole Faust	return m, nil
656*4947cdc7SCole Faust}
657*4947cdc7SCole Faust
658*4947cdc7SCole Faustfunc (m *Message) String() string {
659*4947cdc7SCole Faust	buf := new(bytes.Buffer)
660*4947cdc7SCole Faust	buf.WriteString(string(m.desc().FullName()))
661*4947cdc7SCole Faust	buf.WriteByte('(')
662*4947cdc7SCole Faust
663*4947cdc7SCole Faust	// Sort fields (including extensions) by number.
664*4947cdc7SCole Faust	var fields []protoreflect.FieldDescriptor
665*4947cdc7SCole Faust	m.msg.Range(func(fdesc protoreflect.FieldDescriptor, v protoreflect.Value) bool {
666*4947cdc7SCole Faust		// TODO(adonovan): opt: save v in table too.
667*4947cdc7SCole Faust		fields = append(fields, fdesc)
668*4947cdc7SCole Faust		return true
669*4947cdc7SCole Faust	})
670*4947cdc7SCole Faust	sort.Slice(fields, func(i, j int) bool {
671*4947cdc7SCole Faust		return fields[i].Number() < fields[j].Number()
672*4947cdc7SCole Faust	})
673*4947cdc7SCole Faust
674*4947cdc7SCole Faust	for i, fdesc := range fields {
675*4947cdc7SCole Faust		if i > 0 {
676*4947cdc7SCole Faust			buf.WriteString(", ")
677*4947cdc7SCole Faust		}
678*4947cdc7SCole Faust		if fdesc.IsExtension() {
679*4947cdc7SCole Faust			// extension field: "[pkg.Msg.field]"
680*4947cdc7SCole Faust			buf.WriteString(string(fdesc.FullName()))
681*4947cdc7SCole Faust		} else if fdesc.Kind() != protoreflect.GroupKind {
682*4947cdc7SCole Faust			// ordinary field: "field"
683*4947cdc7SCole Faust			buf.WriteString(string(fdesc.Name()))
684*4947cdc7SCole Faust		} else {
685*4947cdc7SCole Faust			// group field: "MyGroup"
686*4947cdc7SCole Faust			//
687*4947cdc7SCole Faust			// The name of a group is the mangled version,
688*4947cdc7SCole Faust			// while the true name of a group is the message itself.
689*4947cdc7SCole Faust			// For example, for a group called "MyGroup",
690*4947cdc7SCole Faust			// the inlined message will be called "MyGroup",
691*4947cdc7SCole Faust			// but the field will be named "mygroup".
692*4947cdc7SCole Faust			// This rule complicates name logic everywhere.
693*4947cdc7SCole Faust			buf.WriteString(string(fdesc.Message().Name()))
694*4947cdc7SCole Faust		}
695*4947cdc7SCole Faust		buf.WriteString("=")
696*4947cdc7SCole Faust		writeString(buf, fdesc, m.msg.Get(fdesc))
697*4947cdc7SCole Faust	}
698*4947cdc7SCole Faust	buf.WriteByte(')')
699*4947cdc7SCole Faust	return buf.String()
700*4947cdc7SCole Faust}
701*4947cdc7SCole Faust
702*4947cdc7SCole Faustfunc (m *Message) Type() string                { return "proto.Message" }
703*4947cdc7SCole Faustfunc (m *Message) Truth() starlark.Bool        { return true }
704*4947cdc7SCole Faustfunc (m *Message) Freeze()                     { *m.frozen = true }
705*4947cdc7SCole Faustfunc (m *Message) Hash() (h uint32, err error) { return uint32(uintptr(unsafe.Pointer(m))), nil } // identity hash
706*4947cdc7SCole Faust
707*4947cdc7SCole Faust// Attr returns the value of this message's field of the specified name.
708*4947cdc7SCole Faust// Extension fields are not accessible this way as their names are not unique.
709*4947cdc7SCole Faustfunc (m *Message) Attr(name string) (starlark.Value, error) {
710*4947cdc7SCole Faust	// The name 'descriptor' is already effectively reserved
711*4947cdc7SCole Faust	// by the Go API for generated message types.
712*4947cdc7SCole Faust	if name == "descriptor" {
713*4947cdc7SCole Faust		return MessageDescriptor{Desc: m.desc()}, nil
714*4947cdc7SCole Faust	}
715*4947cdc7SCole Faust
716*4947cdc7SCole Faust	fdesc, err := fieldDesc(m.desc(), name)
717*4947cdc7SCole Faust	if err != nil {
718*4947cdc7SCole Faust		return nil, err
719*4947cdc7SCole Faust	}
720*4947cdc7SCole Faust	return m.getField(fdesc), nil
721*4947cdc7SCole Faust}
722*4947cdc7SCole Faust
723*4947cdc7SCole Faustfunc (m *Message) getField(fdesc protoreflect.FieldDescriptor) starlark.Value {
724*4947cdc7SCole Faust	if fdesc.IsExtension() {
725*4947cdc7SCole Faust		// See explanation in setField.
726*4947cdc7SCole Faust		fdesc = dynamicpb.NewExtensionType(fdesc).TypeDescriptor()
727*4947cdc7SCole Faust	}
728*4947cdc7SCole Faust
729*4947cdc7SCole Faust	if m.msg.Has(fdesc) {
730*4947cdc7SCole Faust		return toStarlark(fdesc, m.msg.Get(fdesc), m.frozen)
731*4947cdc7SCole Faust	}
732*4947cdc7SCole Faust	return defaultValue(fdesc)
733*4947cdc7SCole Faust}
734*4947cdc7SCole Faust
// detrandDisable is bound by go:linkname to the Disable function of the
// protobuf module's internal detrand package, which cannot be imported
// directly. The declaration deliberately has no body.
//
//go:linkname detrandDisable google.golang.org/protobuf/internal/detrand.Disable
func detrandDisable()

// init calls detrandDisable once, when this package is initialized.
// NOTE(review): the effect presumably applies process-wide to protobuf
// output, not only to output produced by this package. Confirm.
func init() {
	// Nasty hack to disable the randomization of output that occurs in textproto.
	// TODO(adonovan): once go/proto-proposals/canonical-serialization
	// is resolved the need for the hack should go away. See also go/go-proto-stability.
	// If the proposal is rejected, we will need our own text-mode formatter.
	detrandDisable()
}
745*4947cdc7SCole Faust
746*4947cdc7SCole Faust// defaultValue returns the (frozen) default Starlark value for a given message field.
747*4947cdc7SCole Faustfunc defaultValue(fdesc protoreflect.FieldDescriptor) starlark.Value {
748*4947cdc7SCole Faust	frozen := true
749*4947cdc7SCole Faust
750*4947cdc7SCole Faust	// The default value of a repeated field is an empty list.
751*4947cdc7SCole Faust	if fdesc.IsList() {
752*4947cdc7SCole Faust		return &RepeatedField{typ: fdesc, list: emptyList{}, frozen: &frozen}
753*4947cdc7SCole Faust	}
754*4947cdc7SCole Faust
755*4947cdc7SCole Faust	// The zero value for a message type is an empty instance of that message.
756*4947cdc7SCole Faust	if desc := fdesc.Message(); desc != nil {
757*4947cdc7SCole Faust		return &Message{msg: newMessage(desc), frozen: &frozen}
758*4947cdc7SCole Faust	}
759*4947cdc7SCole Faust
760*4947cdc7SCole Faust	// Convert the default value, which is not necessarily zero, to Starlark.
761*4947cdc7SCole Faust	// The frozenness isn't used as the remaining types are all immutable.
762*4947cdc7SCole Faust	return toStarlark1(fdesc, fdesc.Default(), &frozen)
763*4947cdc7SCole Faust}
764*4947cdc7SCole Faust
765*4947cdc7SCole Faust// A frozen empty implementation of protoreflect.List.
766*4947cdc7SCole Fausttype emptyList struct{ protoreflect.List }
767*4947cdc7SCole Faust
768*4947cdc7SCole Faustfunc (emptyList) Len() int { return 0 }
769*4947cdc7SCole Faust
770*4947cdc7SCole Faust// newMessage returns a new empty instance of the message type described by desc.
771*4947cdc7SCole Faustfunc newMessage(desc protoreflect.MessageDescriptor) protoreflect.Message {
772*4947cdc7SCole Faust	// If desc refers to a built-in message,
773*4947cdc7SCole Faust	// use the more efficient generated type descriptor (a Go struct).
774*4947cdc7SCole Faust	mt, err := protoregistry.GlobalTypes.FindMessageByName(desc.FullName())
775*4947cdc7SCole Faust	if err == nil && mt.Descriptor() == desc {
776*4947cdc7SCole Faust		return mt.New()
777*4947cdc7SCole Faust	}
778*4947cdc7SCole Faust
779*4947cdc7SCole Faust	// For all others, use the generic dynamicpb representation.
780*4947cdc7SCole Faust	return dynamicpb.NewMessage(desc).ProtoReflect()
781*4947cdc7SCole Faust}
782*4947cdc7SCole Faust
783*4947cdc7SCole Faust// fieldDesc returns the descriptor for the named non-extension field.
784*4947cdc7SCole Faustfunc fieldDesc(desc protoreflect.MessageDescriptor, name string) (protoreflect.FieldDescriptor, error) {
785*4947cdc7SCole Faust	if fdesc := desc.Fields().ByName(protoreflect.Name(name)); fdesc != nil {
786*4947cdc7SCole Faust		return fdesc, nil
787*4947cdc7SCole Faust	}
788*4947cdc7SCole Faust	return nil, starlark.NoSuchAttrError(fmt.Sprintf("%s has no .%s field", desc.FullName(), name))
789*4947cdc7SCole Faust}
790*4947cdc7SCole Faust
791*4947cdc7SCole Faust// SetField updates a non-extension field of this message.
792*4947cdc7SCole Faust// It implements the HasSetField interface.
793*4947cdc7SCole Faustfunc (m *Message) SetField(name string, v starlark.Value) error {
794*4947cdc7SCole Faust	fdesc, err := fieldDesc(m.desc(), name)
795*4947cdc7SCole Faust	if err != nil {
796*4947cdc7SCole Faust		return err
797*4947cdc7SCole Faust	}
798*4947cdc7SCole Faust	if *m.frozen {
799*4947cdc7SCole Faust		return fmt.Errorf("cannot set .%s field of frozen %s message",
800*4947cdc7SCole Faust			name, m.desc().FullName())
801*4947cdc7SCole Faust	}
802*4947cdc7SCole Faust	return setField(m.msg, fdesc, v)
803*4947cdc7SCole Faust}
804*4947cdc7SCole Faust
805*4947cdc7SCole Faust// AttrNames returns the set of field names defined for this message.
806*4947cdc7SCole Faust// It satisfies the starlark.HasAttrs interface.
807*4947cdc7SCole Faustfunc (m *Message) AttrNames() []string {
808*4947cdc7SCole Faust	seen := make(map[string]bool)
809*4947cdc7SCole Faust
810*4947cdc7SCole Faust	// standard fields
811*4947cdc7SCole Faust	seen["descriptor"] = true
812*4947cdc7SCole Faust
813*4947cdc7SCole Faust	// non-extension fields
814*4947cdc7SCole Faust	fields := m.desc().Fields()
815*4947cdc7SCole Faust	for i := 0; i < fields.Len(); i++ {
816*4947cdc7SCole Faust		fdesc := fields.Get(i)
817*4947cdc7SCole Faust		if !fdesc.IsExtension() {
818*4947cdc7SCole Faust			seen[string(fdesc.Name())] = true
819*4947cdc7SCole Faust		}
820*4947cdc7SCole Faust	}
821*4947cdc7SCole Faust
822*4947cdc7SCole Faust	names := make([]string, 0, len(seen))
823*4947cdc7SCole Faust	for name := range seen {
824*4947cdc7SCole Faust		names = append(names, name)
825*4947cdc7SCole Faust	}
826*4947cdc7SCole Faust	sort.Strings(names)
827*4947cdc7SCole Faust	return names
828*4947cdc7SCole Faust}
829*4947cdc7SCole Faust
830*4947cdc7SCole Faust// typeString returns a user-friendly description of the type of a
831*4947cdc7SCole Faust// protocol message field (or element of a repeated field).
832*4947cdc7SCole Faustfunc typeString(fdesc protoreflect.FieldDescriptor) string {
833*4947cdc7SCole Faust	switch fdesc.Kind() {
834*4947cdc7SCole Faust	case protoreflect.GroupKind,
835*4947cdc7SCole Faust		protoreflect.MessageKind:
836*4947cdc7SCole Faust		return string(fdesc.Message().FullName())
837*4947cdc7SCole Faust
838*4947cdc7SCole Faust	case protoreflect.EnumKind:
839*4947cdc7SCole Faust		return string(fdesc.Enum().FullName())
840*4947cdc7SCole Faust
841*4947cdc7SCole Faust	default:
842*4947cdc7SCole Faust		return strings.ToLower(strings.TrimPrefix(fdesc.Kind().String(), "TYPE_"))
843*4947cdc7SCole Faust	}
844*4947cdc7SCole Faust}
845*4947cdc7SCole Faust
// A RepeatedField is a Starlark value that wraps a repeated field of a protocol message.
//
// An assignment to an element of a repeated field incurs a dynamic
// check that the new value has (or can be converted to) the correct
// type using conversions similar to those done when calling a
// MessageDescriptor to construct a message.
//
// TODO(adonovan): make RepeatedField implement starlark.Comparable.
// Should the comparison include type, or be defined on the elements alone?
type RepeatedField struct {
	typ       protoreflect.FieldDescriptor // only for type information, not field name
	list      protoreflect.List            // underlying elements, converted to Starlark on access
	frozen    *bool                        // may be shared with related Message/RepeatedField wrappers
	itercount int                          // live iterators over an unfrozen field; SetIndex fails while > 0
}

// RepeatedField supports element assignment (rf[i] = x) through SetIndex.
var _ starlark.HasSetIndex = (*RepeatedField)(nil)
863*4947cdc7SCole Faust
864*4947cdc7SCole Faustfunc (rf *RepeatedField) Type() string {
865*4947cdc7SCole Faust	return fmt.Sprintf("proto.repeated<%s>", typeString(rf.typ))
866*4947cdc7SCole Faust}
867*4947cdc7SCole Faust
868*4947cdc7SCole Faustfunc (rf *RepeatedField) SetIndex(i int, v starlark.Value) error {
869*4947cdc7SCole Faust	if *rf.frozen {
870*4947cdc7SCole Faust		return fmt.Errorf("cannot insert value in frozen repeated field")
871*4947cdc7SCole Faust	}
872*4947cdc7SCole Faust	if rf.itercount > 0 {
873*4947cdc7SCole Faust		return fmt.Errorf("cannot insert value in repeated field with active iterators")
874*4947cdc7SCole Faust	}
875*4947cdc7SCole Faust	x, err := toProto(rf.typ, v)
876*4947cdc7SCole Faust	if err != nil {
877*4947cdc7SCole Faust		// The repeated field value cannot know which field it
878*4947cdc7SCole Faust		// belongs to---it might be shared by several of the
879*4947cdc7SCole Faust		// same type---so the error message is suboptimal.
880*4947cdc7SCole Faust		return fmt.Errorf("setting element of repeated field: %v", err)
881*4947cdc7SCole Faust	}
882*4947cdc7SCole Faust	rf.list.Set(i, x)
883*4947cdc7SCole Faust	return nil
884*4947cdc7SCole Faust}
885*4947cdc7SCole Faust
886*4947cdc7SCole Faustfunc (rf *RepeatedField) Freeze()               { *rf.frozen = true }
887*4947cdc7SCole Faustfunc (rf *RepeatedField) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable: %s", rf.Type()) }
888*4947cdc7SCole Faustfunc (rf *RepeatedField) Index(i int) starlark.Value {
889*4947cdc7SCole Faust	return toStarlark1(rf.typ, rf.list.Get(i), rf.frozen)
890*4947cdc7SCole Faust}
891*4947cdc7SCole Faustfunc (rf *RepeatedField) Iterate() starlark.Iterator {
892*4947cdc7SCole Faust	if !*rf.frozen {
893*4947cdc7SCole Faust		rf.itercount++
894*4947cdc7SCole Faust	}
895*4947cdc7SCole Faust	return &repeatedFieldIterator{rf, 0}
896*4947cdc7SCole Faust}
897*4947cdc7SCole Faustfunc (rf *RepeatedField) Len() int { return rf.list.Len() }
898*4947cdc7SCole Faustfunc (rf *RepeatedField) String() string {
899*4947cdc7SCole Faust	// We use list [...] notation even though it not exactly a list.
900*4947cdc7SCole Faust	buf := new(bytes.Buffer)
901*4947cdc7SCole Faust	buf.WriteByte('[')
902*4947cdc7SCole Faust	for i := 0; i < rf.list.Len(); i++ {
903*4947cdc7SCole Faust		if i > 0 {
904*4947cdc7SCole Faust			buf.WriteString(", ")
905*4947cdc7SCole Faust		}
906*4947cdc7SCole Faust		writeString(buf, rf.typ, rf.list.Get(i))
907*4947cdc7SCole Faust	}
908*4947cdc7SCole Faust	buf.WriteByte(']')
909*4947cdc7SCole Faust	return buf.String()
910*4947cdc7SCole Faust}
911*4947cdc7SCole Faustfunc (rf *RepeatedField) Truth() starlark.Bool { return rf.list.Len() > 0 }
912*4947cdc7SCole Faust
913*4947cdc7SCole Fausttype repeatedFieldIterator struct {
914*4947cdc7SCole Faust	rf *RepeatedField
915*4947cdc7SCole Faust	i  int
916*4947cdc7SCole Faust}
917*4947cdc7SCole Faust
918*4947cdc7SCole Faustfunc (it *repeatedFieldIterator) Next(p *starlark.Value) bool {
919*4947cdc7SCole Faust	if it.i < it.rf.Len() {
920*4947cdc7SCole Faust		*p = it.rf.Index(it.i)
921*4947cdc7SCole Faust		it.i++
922*4947cdc7SCole Faust		return true
923*4947cdc7SCole Faust	}
924*4947cdc7SCole Faust	return false
925*4947cdc7SCole Faust}
926*4947cdc7SCole Faust
927*4947cdc7SCole Faustfunc (it *repeatedFieldIterator) Done() {
928*4947cdc7SCole Faust	if !*it.rf.frozen {
929*4947cdc7SCole Faust		it.rf.itercount--
930*4947cdc7SCole Faust	}
931*4947cdc7SCole Faust}
932*4947cdc7SCole Faust
933*4947cdc7SCole Faustfunc writeString(buf *bytes.Buffer, fdesc protoreflect.FieldDescriptor, v protoreflect.Value) {
934*4947cdc7SCole Faust	// TODO(adonovan): opt: don't materialize the Starlark value.
935*4947cdc7SCole Faust	// TODO(adonovan): skip message type when printing submessages? {...}?
936*4947cdc7SCole Faust	var frozen bool // ignored
937*4947cdc7SCole Faust	x := toStarlark(fdesc, v, &frozen)
938*4947cdc7SCole Faust	buf.WriteString(x.String())
939*4947cdc7SCole Faust}
940*4947cdc7SCole Faust
941*4947cdc7SCole Faust// -------- descriptor values --------
942*4947cdc7SCole Faust
// A FileDescriptor is an immutable Starlark value that describes a
// .proto file.  It is a reference to a protoreflect.FileDescriptor.
// Two FileDescriptor values compare equal if and only if they refer to
// the same protoreflect.FileDescriptor.
//
// Its fields are the names of the message types (MessageDescriptor),
// extension fields (FieldDescriptor), and enum types (EnumDescriptor)
// declared in the file.
type FileDescriptor struct {
	Desc protoreflect.FileDescriptor // TODO(adonovan): hide field, expose method?
}
953*4947cdc7SCole Faust
954*4947cdc7SCole Faustvar _ starlark.HasAttrs = FileDescriptor{}
955*4947cdc7SCole Faust
956*4947cdc7SCole Faustfunc (f FileDescriptor) String() string              { return string(f.Desc.Path()) }
957*4947cdc7SCole Faustfunc (f FileDescriptor) Type() string                { return "proto.FileDescriptor" }
958*4947cdc7SCole Faustfunc (f FileDescriptor) Truth() starlark.Bool        { return true }
959*4947cdc7SCole Faustfunc (f FileDescriptor) Freeze()                     {} // immutable
960*4947cdc7SCole Faustfunc (f FileDescriptor) Hash() (h uint32, err error) { return starlark.String(f.Desc.Path()).Hash() }
961*4947cdc7SCole Faustfunc (f FileDescriptor) Attr(name string) (starlark.Value, error) {
962*4947cdc7SCole Faust	if desc := f.Desc.Messages().ByName(protoreflect.Name(name)); desc != nil {
963*4947cdc7SCole Faust		return MessageDescriptor{Desc: desc}, nil
964*4947cdc7SCole Faust	}
965*4947cdc7SCole Faust	if desc := f.Desc.Extensions().ByName(protoreflect.Name(name)); desc != nil {
966*4947cdc7SCole Faust		return FieldDescriptor{desc}, nil
967*4947cdc7SCole Faust	}
968*4947cdc7SCole Faust	if enum := f.Desc.Enums().ByName(protoreflect.Name(name)); enum != nil {
969*4947cdc7SCole Faust		return EnumDescriptor{Desc: enum}, nil
970*4947cdc7SCole Faust	}
971*4947cdc7SCole Faust	return nil, nil
972*4947cdc7SCole Faust}
973*4947cdc7SCole Faustfunc (f FileDescriptor) AttrNames() []string {
974*4947cdc7SCole Faust	var names []string
975*4947cdc7SCole Faust	messages := f.Desc.Messages()
976*4947cdc7SCole Faust	for i, n := 0, messages.Len(); i < n; i++ {
977*4947cdc7SCole Faust		names = append(names, string(messages.Get(i).Name()))
978*4947cdc7SCole Faust	}
979*4947cdc7SCole Faust	extensions := f.Desc.Extensions()
980*4947cdc7SCole Faust	for i, n := 0, extensions.Len(); i < n; i++ {
981*4947cdc7SCole Faust		names = append(names, string(extensions.Get(i).Name()))
982*4947cdc7SCole Faust	}
983*4947cdc7SCole Faust	enums := f.Desc.Enums()
984*4947cdc7SCole Faust	for i, n := 0, enums.Len(); i < n; i++ {
985*4947cdc7SCole Faust		names = append(names, string(enums.Get(i).Name()))
986*4947cdc7SCole Faust	}
987*4947cdc7SCole Faust	sort.Strings(names)
988*4947cdc7SCole Faust	return names
989*4947cdc7SCole Faust}
990*4947cdc7SCole Faust
// A MessageDescriptor is an immutable Starlark value that describes a protocol
// message type.
//
// A MessageDescriptor value contains a reference to a protoreflect.MessageDescriptor.
// Two MessageDescriptor values compare equal if and only if they refer to the
// same protoreflect.MessageDescriptor.
//
// The fields of a MessageDescriptor value are the names of any message types
// (MessageDescriptor), fields or extension fields (FieldDescriptor),
// and enum types (EnumDescriptor) nested within the declaration of this message type.
type MessageDescriptor struct {
	Desc protoreflect.MessageDescriptor
}

// MessageDescriptor values can be called (to construct messages) and have
// attributes.
var (
	_ starlark.Callable = MessageDescriptor{}
	_ starlark.HasAttrs = MessageDescriptor{}
)
1009*4947cdc7SCole Faust
1010*4947cdc7SCole Faustfunc (d MessageDescriptor) String() string       { return string(d.Desc.FullName()) }
1011*4947cdc7SCole Faustfunc (d MessageDescriptor) Type() string         { return "proto.MessageDescriptor" }
1012*4947cdc7SCole Faustfunc (d MessageDescriptor) Truth() starlark.Bool { return true }
1013*4947cdc7SCole Faustfunc (d MessageDescriptor) Freeze()              {} // immutable
1014*4947cdc7SCole Faustfunc (d MessageDescriptor) Hash() (h uint32, err error) {
1015*4947cdc7SCole Faust	return starlark.String(d.Desc.FullName()).Hash()
1016*4947cdc7SCole Faust}
1017*4947cdc7SCole Faustfunc (d MessageDescriptor) Attr(name string) (starlark.Value, error) {
1018*4947cdc7SCole Faust	if desc := d.Desc.Messages().ByName(protoreflect.Name(name)); desc != nil {
1019*4947cdc7SCole Faust		return MessageDescriptor{desc}, nil
1020*4947cdc7SCole Faust	}
1021*4947cdc7SCole Faust	if desc := d.Desc.Extensions().ByName(protoreflect.Name(name)); desc != nil {
1022*4947cdc7SCole Faust		return FieldDescriptor{desc}, nil
1023*4947cdc7SCole Faust	}
1024*4947cdc7SCole Faust	if desc := d.Desc.Fields().ByName(protoreflect.Name(name)); desc != nil {
1025*4947cdc7SCole Faust		return FieldDescriptor{desc}, nil
1026*4947cdc7SCole Faust	}
1027*4947cdc7SCole Faust	if desc := d.Desc.Enums().ByName(protoreflect.Name(name)); desc != nil {
1028*4947cdc7SCole Faust		return EnumDescriptor{desc}, nil
1029*4947cdc7SCole Faust	}
1030*4947cdc7SCole Faust	return nil, nil
1031*4947cdc7SCole Faust}
1032*4947cdc7SCole Faustfunc (d MessageDescriptor) AttrNames() []string {
1033*4947cdc7SCole Faust	var names []string
1034*4947cdc7SCole Faust	messages := d.Desc.Messages()
1035*4947cdc7SCole Faust	for i, n := 0, messages.Len(); i < n; i++ {
1036*4947cdc7SCole Faust		names = append(names, string(messages.Get(i).Name()))
1037*4947cdc7SCole Faust	}
1038*4947cdc7SCole Faust	enums := d.Desc.Enums()
1039*4947cdc7SCole Faust	for i, n := 0, enums.Len(); i < n; i++ {
1040*4947cdc7SCole Faust		names = append(names, string(enums.Get(i).Name()))
1041*4947cdc7SCole Faust	}
1042*4947cdc7SCole Faust	sort.Strings(names)
1043*4947cdc7SCole Faust	return names
1044*4947cdc7SCole Faust}
1045*4947cdc7SCole Faustfunc (d MessageDescriptor) Name() string { return string(d.Desc.Name()) } // for Callable
1046*4947cdc7SCole Faust
// A FieldDescriptor is an immutable Starlark value that describes
// a field (possibly an extension field) of a protocol message.
//
// A FieldDescriptor value contains a reference to a protoreflect.FieldDescriptor.
// Two FieldDescriptor values compare equal if and only if they refer to the
// same protoreflect.FieldDescriptor.
//
// The primary use for FieldDescriptors is to access extension fields of a message.
//
// A FieldDescriptor value has no attributes.
// TODO(adonovan): expose metadata fields (e.g. name, type).
type FieldDescriptor struct {
	Desc protoreflect.FieldDescriptor
}

// FieldDescriptor satisfies starlark.HasAttrs, although it currently
// reports no attributes.
var (
	_ starlark.HasAttrs = FieldDescriptor{}
)
1065*4947cdc7SCole Faust
1066*4947cdc7SCole Faustfunc (d FieldDescriptor) String() string       { return string(d.Desc.FullName()) }
1067*4947cdc7SCole Faustfunc (d FieldDescriptor) Type() string         { return "proto.FieldDescriptor" }
1068*4947cdc7SCole Faustfunc (d FieldDescriptor) Truth() starlark.Bool { return true }
1069*4947cdc7SCole Faustfunc (d FieldDescriptor) Freeze()              {} // immutable
1070*4947cdc7SCole Faustfunc (d FieldDescriptor) Hash() (h uint32, err error) {
1071*4947cdc7SCole Faust	return starlark.String(d.Desc.FullName()).Hash()
1072*4947cdc7SCole Faust}
1073*4947cdc7SCole Faustfunc (d FieldDescriptor) Attr(name string) (starlark.Value, error) {
1074*4947cdc7SCole Faust	// TODO(adonovan): expose metadata fields of Desc?
1075*4947cdc7SCole Faust	return nil, nil
1076*4947cdc7SCole Faust}
1077*4947cdc7SCole Faustfunc (d FieldDescriptor) AttrNames() []string {
1078*4947cdc7SCole Faust	var names []string
1079*4947cdc7SCole Faust	// TODO(adonovan): expose metadata fields of Desc?
1080*4947cdc7SCole Faust	sort.Strings(names)
1081*4947cdc7SCole Faust	return names
1082*4947cdc7SCole Faust}
1083*4947cdc7SCole Faust
// An EnumDescriptor is an immutable Starlark value that describes a
// protocol enum type.
//
// An EnumDescriptor contains a reference to a protoreflect.EnumDescriptor.
// Two EnumDescriptor values compare equal if and only if they
// refer to the same protoreflect.EnumDescriptor.
//
// An EnumDescriptor may be called like a function.  It converts its
// sole argument, which must be an int, string, or EnumValueDescriptor,
// to an EnumValueDescriptor.
//
// The fields of an EnumDescriptor value are the values of the
// enumeration, each of type EnumValueDescriptor.
type EnumDescriptor struct {
	Desc protoreflect.EnumDescriptor
}

// EnumDescriptor values have attributes (the enum's values) and can be
// called to convert a value to this enum type.
var (
	_ starlark.HasAttrs = EnumDescriptor{}
	_ starlark.Callable = EnumDescriptor{}
)
1105*4947cdc7SCole Faust
1106*4947cdc7SCole Faustfunc (e EnumDescriptor) String() string              { return string(e.Desc.FullName()) }
1107*4947cdc7SCole Faustfunc (e EnumDescriptor) Type() string                { return "proto.EnumDescriptor" }
1108*4947cdc7SCole Faustfunc (e EnumDescriptor) Truth() starlark.Bool        { return true }
1109*4947cdc7SCole Faustfunc (e EnumDescriptor) Freeze()                     {}                // immutable
1110*4947cdc7SCole Faustfunc (e EnumDescriptor) Hash() (h uint32, err error) { return 0, nil } // TODO(adonovan): number?
1111*4947cdc7SCole Faustfunc (e EnumDescriptor) Attr(name string) (starlark.Value, error) {
1112*4947cdc7SCole Faust	if v := e.Desc.Values().ByName(protoreflect.Name(name)); v != nil {
1113*4947cdc7SCole Faust		return EnumValueDescriptor{v}, nil
1114*4947cdc7SCole Faust	}
1115*4947cdc7SCole Faust	return nil, nil
1116*4947cdc7SCole Faust}
1117*4947cdc7SCole Faustfunc (e EnumDescriptor) AttrNames() []string {
1118*4947cdc7SCole Faust	var names []string
1119*4947cdc7SCole Faust	values := e.Desc.Values()
1120*4947cdc7SCole Faust	for i, n := 0, values.Len(); i < n; i++ {
1121*4947cdc7SCole Faust		names = append(names, string(values.Get(i).Name()))
1122*4947cdc7SCole Faust	}
1123*4947cdc7SCole Faust	sort.Strings(names)
1124*4947cdc7SCole Faust	return names
1125*4947cdc7SCole Faust}
1126*4947cdc7SCole Faustfunc (e EnumDescriptor) Name() string { return string(e.Desc.Name()) } // for Callable
1127*4947cdc7SCole Faust
1128*4947cdc7SCole Faust// The Call method implements the starlark.Callable interface.
1129*4947cdc7SCole Faust// A call to an enum descriptor converts its argument to a value of that enum type.
1130*4947cdc7SCole Faustfunc (e EnumDescriptor) CallInternal(_ *starlark.Thread, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
1131*4947cdc7SCole Faust	var x starlark.Value
1132*4947cdc7SCole Faust	if err := starlark.UnpackPositionalArgs(string(e.Desc.Name()), args, kwargs, 1, &x); err != nil {
1133*4947cdc7SCole Faust		return nil, err
1134*4947cdc7SCole Faust	}
1135*4947cdc7SCole Faust	v, err := enumValueOf(e.Desc, x)
1136*4947cdc7SCole Faust	if err != nil {
1137*4947cdc7SCole Faust		return nil, fmt.Errorf("%s: %v", e.Desc.Name(), err)
1138*4947cdc7SCole Faust	}
1139*4947cdc7SCole Faust	return EnumValueDescriptor{Desc: v}, nil
1140*4947cdc7SCole Faust}
1141*4947cdc7SCole Faust
1142*4947cdc7SCole Faust// enumValueOf converts an int, string, or enum value to a value of the specified enum type.
1143*4947cdc7SCole Faustfunc enumValueOf(enum protoreflect.EnumDescriptor, x starlark.Value) (protoreflect.EnumValueDescriptor, error) {
1144*4947cdc7SCole Faust	switch x := x.(type) {
1145*4947cdc7SCole Faust	case starlark.Int:
1146*4947cdc7SCole Faust		i, err := starlark.AsInt32(x)
1147*4947cdc7SCole Faust		if err != nil {
1148*4947cdc7SCole Faust			return nil, fmt.Errorf("invalid number %s for %s enum", x, enum.Name())
1149*4947cdc7SCole Faust		}
1150*4947cdc7SCole Faust		desc := enum.Values().ByNumber(protoreflect.EnumNumber(i))
1151*4947cdc7SCole Faust		if desc == nil {
1152*4947cdc7SCole Faust			return nil, fmt.Errorf("invalid number %d for %s enum", i, enum.Name())
1153*4947cdc7SCole Faust		}
1154*4947cdc7SCole Faust		return desc, nil
1155*4947cdc7SCole Faust
1156*4947cdc7SCole Faust	case starlark.String:
1157*4947cdc7SCole Faust		name := protoreflect.Name(x)
1158*4947cdc7SCole Faust		desc := enum.Values().ByName(name)
1159*4947cdc7SCole Faust		if desc == nil {
1160*4947cdc7SCole Faust			return nil, fmt.Errorf("invalid name %q for %s enum", name, enum.Name())
1161*4947cdc7SCole Faust		}
1162*4947cdc7SCole Faust		return desc, nil
1163*4947cdc7SCole Faust
1164*4947cdc7SCole Faust	case EnumValueDescriptor:
1165*4947cdc7SCole Faust		if parent := x.Desc.Parent(); parent != enum {
1166*4947cdc7SCole Faust			return nil, fmt.Errorf("invalid value %s.%s for %s enum",
1167*4947cdc7SCole Faust				parent.Name(), x.Desc.Name(), enum.Name())
1168*4947cdc7SCole Faust		}
1169*4947cdc7SCole Faust		return x.Desc, nil
1170*4947cdc7SCole Faust	}
1171*4947cdc7SCole Faust
1172*4947cdc7SCole Faust	return nil, fmt.Errorf("cannot convert %s to %s enum", x.Type(), enum.Name())
1173*4947cdc7SCole Faust}
1174*4947cdc7SCole Faust
1175*4947cdc7SCole Faust// An EnumValueDescriptor is an immutable Starlark value that represents one value of an enumeration.
1176*4947cdc7SCole Faust//
1177*4947cdc7SCole Faust// An EnumValueDescriptor contains a reference to a protoreflect.EnumValueDescriptor.
1178*4947cdc7SCole Faust// Two EnumValueDescriptor values compare equal if and only if they
1179*4947cdc7SCole Faust// refer to the same protoreflect.EnumValueDescriptor.
1180*4947cdc7SCole Faust//
1181*4947cdc7SCole Faust// An EnumValueDescriptor has the following fields:
1182*4947cdc7SCole Faust//
1183*4947cdc7SCole Faust//      index   -- int, index of this value within the enum sequence
1184*4947cdc7SCole Faust//      name    -- string, name of this enum value
1185*4947cdc7SCole Faust//      number  -- int, numeric value of this enum value
1186*4947cdc7SCole Faust//      type    -- EnumDescriptor, the enum type to which this value belongs
1187*4947cdc7SCole Faust//
1188*4947cdc7SCole Fausttype EnumValueDescriptor struct {
1189*4947cdc7SCole Faust	Desc protoreflect.EnumValueDescriptor
1190*4947cdc7SCole Faust}
1191*4947cdc7SCole Faust
1192*4947cdc7SCole Faustvar (
1193*4947cdc7SCole Faust	_ starlark.HasAttrs   = EnumValueDescriptor{}
1194*4947cdc7SCole Faust	_ starlark.Comparable = EnumValueDescriptor{}
1195*4947cdc7SCole Faust)
1196*4947cdc7SCole Faust
1197*4947cdc7SCole Faustfunc (e EnumValueDescriptor) String() string {
1198*4947cdc7SCole Faust	enum := e.Desc.Parent()
1199*4947cdc7SCole Faust	return string(enum.Name() + "." + e.Desc.Name()) // "Enum.EnumValue"
1200*4947cdc7SCole Faust}
1201*4947cdc7SCole Faustfunc (e EnumValueDescriptor) Type() string                { return "proto.EnumValueDescriptor" }
1202*4947cdc7SCole Faustfunc (e EnumValueDescriptor) Truth() starlark.Bool        { return true }
1203*4947cdc7SCole Faustfunc (e EnumValueDescriptor) Freeze()                     {} // immutable
1204*4947cdc7SCole Faustfunc (e EnumValueDescriptor) Hash() (h uint32, err error) { return uint32(e.Desc.Number()), nil }
1205*4947cdc7SCole Faustfunc (e EnumValueDescriptor) AttrNames() []string {
1206*4947cdc7SCole Faust	return []string{"index", "name", "number", "type"}
1207*4947cdc7SCole Faust}
1208*4947cdc7SCole Faustfunc (e EnumValueDescriptor) Attr(name string) (starlark.Value, error) {
1209*4947cdc7SCole Faust	switch name {
1210*4947cdc7SCole Faust	case "index":
1211*4947cdc7SCole Faust		return starlark.MakeInt(e.Desc.Index()), nil
1212*4947cdc7SCole Faust	case "name":
1213*4947cdc7SCole Faust		return starlark.String(e.Desc.Name()), nil
1214*4947cdc7SCole Faust	case "number":
1215*4947cdc7SCole Faust		return starlark.MakeInt(int(e.Desc.Number())), nil
1216*4947cdc7SCole Faust	case "type":
1217*4947cdc7SCole Faust		enum := e.Desc.Parent()
1218*4947cdc7SCole Faust		return EnumDescriptor{Desc: enum.(protoreflect.EnumDescriptor)}, nil
1219*4947cdc7SCole Faust	}
1220*4947cdc7SCole Faust	return nil, nil
1221*4947cdc7SCole Faust}
1222*4947cdc7SCole Faustfunc (x EnumValueDescriptor) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
1223*4947cdc7SCole Faust	y := y_.(EnumValueDescriptor)
1224*4947cdc7SCole Faust	switch op {
1225*4947cdc7SCole Faust	case syntax.EQL:
1226*4947cdc7SCole Faust		return x.Desc == y.Desc, nil
1227*4947cdc7SCole Faust	case syntax.NEQ:
1228*4947cdc7SCole Faust		return x.Desc != y.Desc, nil
1229*4947cdc7SCole Faust	default:
1230*4947cdc7SCole Faust		return false, fmt.Errorf("%s %s %s not implemented", x.Type(), op, y_.Type())
1231*4947cdc7SCole Faust	}
1232*4947cdc7SCole Faust}
1233