xref: /aosp_15_r20/external/tensorflow/tensorflow/go/tensor.go (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package tensorflow
18
19/*
20#include <stdlib.h>
21#include <string.h>
22#include "tensorflow/c/c_api.h"
23
24void toNewTString(_GoString_ gstr, TF_TString *tstr) {
25    TF_TString_Init(tstr);
26    TF_TString_Copy(tstr, _GoStringPtr(gstr), _GoStringLen(gstr));
27}
28*/
29import "C"
30
31import (
32	"bytes"
33	"fmt"
34	"io"
35	"math/bits"
36	"reflect"
37	"runtime"
38	"unsafe"
39)
40
41// DataType holds the type for a scalar value.  E.g., one slot in a tensor.
42type DataType C.TF_DataType
43
44// Types of scalar values in the TensorFlow type system.
45const (
46	Float      DataType = C.TF_FLOAT
47	Double     DataType = C.TF_DOUBLE
48	Int32      DataType = C.TF_INT32
49	Uint32     DataType = C.TF_UINT32
50	Uint8      DataType = C.TF_UINT8
51	Int16      DataType = C.TF_INT16
52	Int8       DataType = C.TF_INT8
53	String     DataType = C.TF_STRING
54	Complex64  DataType = C.TF_COMPLEX64
55	Complex    DataType = C.TF_COMPLEX
56	Int64      DataType = C.TF_INT64
57	Uint64     DataType = C.TF_UINT64
58	Bool       DataType = C.TF_BOOL
59	Qint8      DataType = C.TF_QINT8
60	Quint8     DataType = C.TF_QUINT8
61	Qint32     DataType = C.TF_QINT32
62	Bfloat16   DataType = C.TF_BFLOAT16
63	Qint16     DataType = C.TF_QINT16
64	Quint16    DataType = C.TF_QUINT16
65	Uint16     DataType = C.TF_UINT16
66	Complex128 DataType = C.TF_COMPLEX128
67	Half       DataType = C.TF_HALF
68)
69
// Tensor holds a multi-dimensional array of elements of a single data type.
//
// A Tensor owns the C TF_Tensor it wraps: the constructors in this file
// register (*Tensor).finalize with runtime.SetFinalizer, which calls
// TF_DeleteTensor once the Go value becomes unreachable.
type Tensor struct {
	// c is the owned C tensor holding the element data.
	c     *C.TF_Tensor
	// shape is the Go-side copy of the tensor's dimensions. Shape returns it
	// directly and Reshape updates it.
	shape []int64
}
75
76// NewTensor converts from a Go value to a Tensor. Valid values are scalars,
77// slices, and arrays. Every element of a slice must have the same length so
78// that the resulting Tensor has a valid shape.
79func NewTensor(value interface{}) (*Tensor, error) {
80	val := reflect.ValueOf(value)
81	shape, dataType, err := shapeAndDataTypeOf(val)
82	if err != nil {
83		return nil, err
84	}
85	nflattened := numElements(shape)
86	nbytes := TypeOf(dataType, nil).Size() * uintptr(nflattened)
87	if dataType == String {
88		nbytes = uintptr(nflattened) * C.sizeof_TF_TString
89	}
90	var shapePtr *C.int64_t
91	if len(shape) > 0 {
92		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
93	}
94	t := &Tensor{
95		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
96		shape: shape,
97	}
98
99	raw := tensorData(t.c)
100
101	runtime.SetFinalizer(t, (*Tensor).finalize)
102
103	buf := bytes.NewBuffer(raw[:0:len(raw)])
104
105	if isAllArray(val.Type()) {
106		// We have arrays all the way down, or just primitive types. We can
107		// just copy the memory in as it is all contiguous.
108		if _, err := copyPtr(buf, unpackEFace(value).data, int(val.Type().Size())); err != nil {
109			return nil, err
110		}
111	} else {
112		// When there are slices involved the memory for each leaf slice may
113		// not be contiguous with the others or in the order we might
114		// expect, so we need to work our way down to each slice of
115		// primitives and copy them individually
116		if _, err := encodeTensorWithSlices(buf, val, shape); err != nil {
117			return nil, err
118		}
119	}
120
121	if uintptr(buf.Len()) != nbytes {
122		return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
123	}
124	return t, nil
125}
126
// isAllArray reports whether typ is a non-string scalar type, or an array
// (possibly nested) whose innermost element is such a scalar. For those types
// the value's data occupies one contiguous block of memory. Strings and
// slices are excluded because they hold pointers to separately allocated data.
func isAllArray(typ reflect.Type) bool {
	elem := typ
	for elem.Kind() == reflect.Array {
		elem = elem.Elem()
	}
	switch elem.Kind() {
	case reflect.String, reflect.Slice:
		return false
	}
	return true
}
143
// eface mirrors how an empty interface value is laid out in memory: a pointer
// to the dynamic type's metadata followed by a pointer to the value. The
// field order is what makes the reinterpretation in unpackEFace valid, so it
// must not change.
type eface struct {
	rtype unsafe.Pointer
	data  unsafe.Pointer
}

// unpackEFace reinterprets obj as an eface. This gives cheap access to a
// pointer to the value carried in the interface, which reflect does not hand
// out for non-addressable values.
//
// When a non-pointer value is stored in an interface, the interface points at
// a copy of it, so data addresses the value's bytes. When a pointer type is
// stored, the pointer itself sits in the interface and data IS that pointer,
// so this helper is only meaningful for the former. Callers should only read
// through data.
func unpackEFace(obj interface{}) *eface {
	header := unsafe.Pointer(&obj)
	return (*eface)(header)
}
165
166// ReadTensor constructs a Tensor with the provided type and shape from the
167// serialized tensor contents in r.
168//
169// See also WriteContentsTo.
170func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
171	if err := isTensorSerializable(dataType); err != nil {
172		return nil, err
173	}
174
175	var shapePtr *C.int64_t
176	if len(shape) > 0 {
177		for _, dim := range shape {
178			if dim < 0 {
179				return nil, fmt.Errorf("all shape dimentions should be non-negative: %v", shape)
180			}
181		}
182		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
183	}
184
185	nbytes := TypeOf(dataType, nil).Size() * uintptr(numElements(shape))
186	t := &Tensor{
187		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
188		shape: shape,
189	}
190	runtime.SetFinalizer(t, (*Tensor).finalize)
191	raw := tensorData(t.c)
192	if _, err := io.ReadFull(r, raw); err != nil {
193		return nil, err
194	}
195	return t, nil
196}
197
198// newTensorFromC takes ownership of c and returns the owning Tensor.
199func newTensorFromC(c *C.TF_Tensor) *Tensor {
200	var shape []int64
201	if ndims := int(C.TF_NumDims(c)); ndims > 0 {
202		shape = make([]int64, ndims)
203	}
204	for i := range shape {
205		shape[i] = int64(C.TF_Dim(c, C.int(i)))
206	}
207	t := &Tensor{c: c, shape: shape}
208	runtime.SetFinalizer(t, (*Tensor).finalize)
209	return t
210}
211
212func (t *Tensor) finalize() { C.TF_DeleteTensor(t.c) }
213
214// DataType returns the scalar datatype of the Tensor.
215func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }
216
217// Shape returns the shape of the Tensor.
218func (t *Tensor) Shape() []int64 { return t.shape }
219
220// Reshape  updates tensor's shape in place if this is possible or returns an error otherwise.
221func (t *Tensor) Reshape(newShape []int64) error {
222	oldShapeSize := numElements(t.shape)
223	newShapeSize := numElements(newShape)
224
225	if oldShapeSize != newShapeSize {
226		return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize)
227	}
228
229	if len(newShape) == 0 {
230		return nil
231	}
232
233	var shapePtr *C.int64_t
234	shapePtr = (*C.int64_t)(unsafe.Pointer(&newShape[0]))
235
236	status := newStatus()
237	C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c)
238
239	if err := status.Err(); err != nil {
240		return err
241	}
242	t.shape = newShape
243	return nil
244}
245
246// Value converts the Tensor to a Go value. For now, not all Tensor types are
247// supported, and this function may panic if it encounters an unsupported
248// DataType.
249//
250// The type of the output depends on the Tensor type and dimensions.
251// For example:
252// Tensor(int64, 0): int64
253// Tensor(float64, 3): [][][]float64
254func (t *Tensor) Value() interface{} {
255	raw := tensorData(t.c)
256	shape := t.Shape()
257	dt := t.DataType()
258	return decodeTensor(raw, shape, dt).Interface()
259}
260
261func decodeTensor(raw []byte, shape []int64, dt DataType) reflect.Value {
262	// Create a 1-dimensional slice of the base large enough for the data and
263	// copy the data in.
264	n := int(numElements(shape))
265
266	var (
267		slice reflect.Value
268		typ   reflect.Type
269	)
270	if dt == String {
271		strs, err := decodeOneDimString(raw, n)
272		if err != nil {
273			panic(bug("unable to decode string with shape %v: %v", shape, err))
274		}
275		slice = reflect.ValueOf(strs)
276		typ = slice.Type()
277	} else {
278		typ = typeForDataType(dt)
279		l := n * int(typ.Size())
280		typ = reflect.SliceOf(typ)
281		slice = reflect.MakeSlice(typ, n, n)
282		baseBytes := *(*[]byte)(unsafe.Pointer(&sliceHeader{
283			Data: unsafe.Pointer(slice.Pointer()),
284			Len:  l,
285			Cap:  l,
286		}))
287		copy(baseBytes, raw)
288	}
289
290	// Now we have the data in place in the base slice we can add the
291	// dimensions. We want to walk backwards through the shape. If the shape is
292	// length 1 or 0 then we're already done.
293	if len(shape) == 0 {
294		return slice.Index(0)
295	}
296	if len(shape) == 1 {
297		return slice
298	}
299	// We have a special case if the tensor has no data. Our backing slice is
300	// empty, but we still want to create slices following the shape. In this
301	// case only the final part of the shape will be 0 and we want to recalculate
302	// n at this point ignoring that 0.
303	// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
304	// want 6 zero length slices to group as follows.
305	// {{} {}} {{} {}} {{} {}}
306	if n == 0 {
307		n = int(numElements(shape[:len(shape)-1]))
308	}
309	for i := len(shape) - 2; i >= 0; i-- {
310		underlyingSize := typ.Elem().Size()
311		typ = reflect.SliceOf(typ)
312		subsliceLen := int(shape[i+1])
313		if subsliceLen != 0 {
314			n = n / subsliceLen
315		}
316		// Just using reflection it is difficult to avoid unnecessary
317		// allocations while setting up the sub-slices as the Slice function on
318		// a slice Value allocates. So we end up doing pointer arithmetic!
319		// Pointer() on a slice gives us access to the data backing the slice.
320		// We insert slice headers directly into this data.
321		data := unsafe.Pointer(slice.Pointer())
322		nextSlice := reflect.MakeSlice(typ, n, n)
323
324		for j := 0; j < n; j++ {
325			// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
326			setSliceInSlice(nextSlice, j, sliceHeader{
327				Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
328				Len:  subsliceLen,
329				Cap:  subsliceLen,
330			})
331		}
332
333		slice = nextSlice
334	}
335	return slice
336}
337
338// setSliceInSlice sets slice[index] = content.
339func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
340	const sliceSize = unsafe.Sizeof(sliceHeader{})
341	// We must cast slice.Pointer to uninptr & back again to avoid GC issues.
342	// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
343	*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
344}
345
346// decodeOneDimString decodes a string tensor into a one-dimensional []string.
347func decodeOneDimString(raw []byte, nStrings int) ([]string, error) {
348	strs := make([]string, nStrings)
349	tstrs := (*(*[]C.TF_TString)(unsafe.Pointer(&raw)))[:nStrings]
350
351	for i, tstr := range tstrs {
352		dst := C.TF_TString_GetDataPointer(&tstr)
353		dstLen := C.TF_TString_GetSize(&tstr)
354
355		strs[i] = C.GoStringN(dst, C.int(dstLen))
356	}
357
358	return strs, nil
359}
360
361// WriteContentsTo writes the serialized contents of t to w.
362//
363// Returns the number of bytes written. See ReadTensor for
364// reconstructing a Tensor from the serialized form.
365//
366// WARNING: WriteContentsTo is not comprehensive and will fail
367// if t.DataType() is non-numeric (e.g., String). See
368// https://github.com/tensorflow/tensorflow/issues/6003.
369func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
370	if err := isTensorSerializable(t.DataType()); err != nil {
371		return 0, err
372	}
373	return io.Copy(w, bytes.NewReader(tensorData(t.c)))
374}
375
376func tensorData(c *C.TF_Tensor) []byte {
377	// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
378	cbytes := C.TF_TensorData(c)
379	if cbytes == nil {
380		return nil
381	}
382	length := int(C.TF_TensorByteSize(c))
383	var slice []byte
384	if unsafe.Sizeof(unsafe.Pointer(nil)) == 8 {
385		slice = (*[1<<50 - 1]byte)(unsafe.Pointer(cbytes))[:length:length]
386	} else {
387		slice = (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length]
388	}
389	return slice
390}
391
// types pairs each supported Go scalar type with its TensorFlow data type.
//
// shapeAndDataTypeOf matches entries by reflect.Kind, so named types with a
// matching underlying kind (e.g. `type MyFloat float32`) resolve to the same
// DataType. typeForDataType matches entries by dataType and panics for a
// DataType that has no entry here.
var types = []struct {
	typ      reflect.Type
	dataType C.TF_DataType
}{
	{reflect.TypeOf(float32(0)), C.TF_FLOAT},
	{reflect.TypeOf(float64(0)), C.TF_DOUBLE},
	{reflect.TypeOf(int32(0)), C.TF_INT32},
	{reflect.TypeOf(uint32(0)), C.TF_UINT32},
	{reflect.TypeOf(uint8(0)), C.TF_UINT8},
	{reflect.TypeOf(int16(0)), C.TF_INT16},
	{reflect.TypeOf(int8(0)), C.TF_INT8},
	{reflect.TypeOf(""), C.TF_STRING},
	{reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64},
	{reflect.TypeOf(int64(0)), C.TF_INT64},
	{reflect.TypeOf(uint64(0)), C.TF_UINT64},
	{reflect.TypeOf(false), C.TF_BOOL},
	{reflect.TypeOf(uint16(0)), C.TF_UINT16},
	{reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128},
	// TODO(apassos): support DT_RESOURCE representation in go.
	// TODO(keveman): support DT_VARIANT representation in go.
}
413
414// shapeAndDataTypeOf returns the data type and shape of the Tensor
415// corresponding to a Go type.
416func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err error) {
417	typ := val.Type()
418	for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice {
419		shape = append(shape, int64(val.Len()))
420		// If slice elements are slices, verify that all of them have the same size.
421		// Go's type system makes that guarantee for arrays.
422		if val.Len() > 0 {
423			if val.Type().Elem().Kind() == reflect.Slice {
424				expected := val.Index(0).Len()
425				for i := 1; i < val.Len(); i++ {
426					if val.Index(i).Len() != expected {
427						return shape, dt, fmt.Errorf("mismatched slice lengths: %d and %d", val.Index(i).Len(), expected)
428					}
429				}
430			}
431			val = val.Index(0)
432		}
433		typ = typ.Elem()
434	}
435	for _, t := range types {
436		if typ.Kind() == t.typ.Kind() {
437			return shape, DataType(t.dataType), nil
438		}
439	}
440	return shape, dt, fmt.Errorf("unsupported type %v", typ)
441}
442
443func typeForDataType(dt DataType) reflect.Type {
444	for _, t := range types {
445		if dt == DataType(t.dataType) {
446			return t.typ
447		}
448	}
449	panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
450}
451
452// TypeOf converts from a DataType and Shape to the equivalent Go type.
453func TypeOf(dt DataType, shape []int64) reflect.Type {
454	ret := typeForDataType(dt)
455	for range shape {
456		ret = reflect.SliceOf(ret)
457	}
458	return ret
459}
460
// numElements returns the number of elements described by shape: the product
// of its dimensions. An empty shape (a scalar) has one element.
func numElements(shape []int64) int64 {
	total := int64(1)
	for i := range shape {
		total *= shape[i]
	}
	return total
}
468
// sizeVarUint returns how many bytes encoding v as an unsigned varint takes.
// Each varint byte carries 7 bits of payload, and even 0 needs one byte.
func sizeVarUint(v uint64) int {
	if v >= 0x80 {
		return (bits.Len64(v) + 6) / 7
	}
	return 1
}
478
479// encodeTensorWithSlices writes v to the specified buffer using the format specified in
480// c_api.h. Use stringEncoder for String tensors.
481func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) (int, error) {
482	// If current dimension is a slice, verify that it has the expected size
483	// Go's type system makes that guarantee for arrays.
484	if v.Kind() == reflect.Slice {
485		expected := int(shape[0])
486		if v.Len() != expected {
487			return 0, fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
488		}
489	} else if v.Kind() == reflect.String {
490		s := v.Interface().(string)
491		var tstr C.TF_TString
492		C.toNewTString(s, &tstr)
493		ptr := unsafe.Pointer(&tstr)
494		return copyPtr(w, ptr, C.sizeof_TF_TString)
495	} else if v.Kind() != reflect.Array {
496		return 0, fmt.Errorf("unsupported type %v", v.Type())
497	}
498
499	// Once we have just a single dimension we can just copy the data
500	if len(shape) == 1 && v.Len() > 0 && v.Index(0).Kind() != reflect.String {
501		elt := v.Index(0)
502		if !elt.CanAddr() {
503			panic("cannot take address")
504		}
505		ptr := unsafe.Pointer(elt.Addr().Pointer())
506		return copyPtr(w, ptr, v.Len()*int(elt.Type().Size()))
507	}
508
509	n := 0
510	subShape := shape[1:]
511	for i := 0; i < v.Len(); i++ {
512		j, err := encodeTensorWithSlices(w, v.Index(i), subShape)
513		if err != nil {
514			return n + j, err
515		}
516		n += j
517	}
518
519	return n, nil
520}
521
// sliceHeader has the same layout as a Go slice header. It is used instead of
// reflect.SliceHeader because that type stores Data as a uintptr, which the
// garbage collector does not treat as a pointer; unsafe.Pointer keeps any Go
// memory it refers to alive.
type sliceHeader struct {
	Data unsafe.Pointer
	Len  int
	Cap  int
}

// copyPtr appends the l bytes starting at ptr to w, and returns the number of
// bytes written and w's error. The bytes are copied as they sit in memory;
// no byte-order conversion is done, because tensors use the machine's native
// order.
func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) (int, error) {
	// View the memory as a []byte so it can be handed to w.Write.
	hdr := sliceHeader{Data: ptr, Len: l, Cap: l}
	view := *(*[]byte)(unsafe.Pointer(&hdr))
	return w.Write(view)
}
542
543func bug(format string, args ...interface{}) error {
544	return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
545}
546
547func isTensorSerializable(dataType DataType) error {
548	// For numeric types, the serialized Tensor matches the in-memory
549	// representation.  See the implementation of Tensor::AsProtoContent in
550	// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
551	//
552	// The more appropriate way to be in sync with Tensor::AsProtoContent
553	// would be to have the TensorFlow C library export functions for
554	// serialization and deserialization of Tensors.  Till then capitalize
555	// on knowledge of the implementation for numeric types.
556	switch dataType {
557	case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half:
558		return nil
559	default:
560		return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
561	}
562}
563