/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tensorflow

// The comment block below is the cgo preamble: it is compiled as C and its
// contents are part of the program, not documentation. toNewTString copies a
// Go string into a freshly initialized TF_TString.
/*
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/c_api.h"

void toNewTString(_GoString_ gstr, TF_TString *tstr) {
	TF_TString_Init(tstr);
	TF_TString_Copy(tstr, _GoStringPtr(gstr), _GoStringLen(gstr));
}
*/
import "C"

import (
	"bytes"
	"fmt"
	"io"
	"math/bits"
	"reflect"
	"runtime"
	"unsafe"
)

// DataType holds the type for a scalar value. E.g., one slot in a tensor.
// It mirrors the C library's TF_DataType enumeration.
type DataType C.TF_DataType

// Types of scalar values in the TensorFlow type system. Each constant is the
// Go name for the corresponding TF_DataType value from the C API.
const (
	Float      DataType = C.TF_FLOAT
	Double     DataType = C.TF_DOUBLE
	Int32      DataType = C.TF_INT32
	Uint32     DataType = C.TF_UINT32
	Uint8      DataType = C.TF_UINT8
	Int16      DataType = C.TF_INT16
	Int8       DataType = C.TF_INT8
	String     DataType = C.TF_STRING
	Complex64  DataType = C.TF_COMPLEX64
	Complex    DataType = C.TF_COMPLEX
	Int64      DataType = C.TF_INT64
	Uint64     DataType = C.TF_UINT64
	Bool       DataType = C.TF_BOOL
	Qint8      DataType = C.TF_QINT8
	Quint8     DataType = C.TF_QUINT8
	Qint32     DataType = C.TF_QINT32
	Bfloat16   DataType = C.TF_BFLOAT16
	Qint16     DataType = C.TF_QINT16
	Quint16    DataType = C.TF_QUINT16
	Uint16     DataType = C.TF_UINT16
	Complex128 DataType = C.TF_COMPLEX128
	Half       DataType = C.TF_HALF
)

// Tensor holds a multi-dimensional array of elements of a single data type.
71type Tensor struct { 72 c *C.TF_Tensor 73 shape []int64 74} 75 76// NewTensor converts from a Go value to a Tensor. Valid values are scalars, 77// slices, and arrays. Every element of a slice must have the same length so 78// that the resulting Tensor has a valid shape. 79func NewTensor(value interface{}) (*Tensor, error) { 80 val := reflect.ValueOf(value) 81 shape, dataType, err := shapeAndDataTypeOf(val) 82 if err != nil { 83 return nil, err 84 } 85 nflattened := numElements(shape) 86 nbytes := TypeOf(dataType, nil).Size() * uintptr(nflattened) 87 if dataType == String { 88 nbytes = uintptr(nflattened) * C.sizeof_TF_TString 89 } 90 var shapePtr *C.int64_t 91 if len(shape) > 0 { 92 shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) 93 } 94 t := &Tensor{ 95 c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), 96 shape: shape, 97 } 98 99 raw := tensorData(t.c) 100 101 runtime.SetFinalizer(t, (*Tensor).finalize) 102 103 buf := bytes.NewBuffer(raw[:0:len(raw)]) 104 105 if isAllArray(val.Type()) { 106 // We have arrays all the way down, or just primitive types. We can 107 // just copy the memory in as it is all contiguous. 
108 if _, err := copyPtr(buf, unpackEFace(value).data, int(val.Type().Size())); err != nil { 109 return nil, err 110 } 111 } else { 112 // When there are slices involved the memory for each leaf slice may 113 // not be contiguous with the others or in the order we might 114 // expect, so we need to work our way down to each slice of 115 // primitives and copy them individually 116 if _, err := encodeTensorWithSlices(buf, val, shape); err != nil { 117 return nil, err 118 } 119 } 120 121 if uintptr(buf.Len()) != nbytes { 122 return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len()) 123 } 124 return t, nil 125} 126 127// isAllArray returns true if type is a primitive type or an array of primitive 128// types or an array of ... etc.. When this is true the data we want is 129// contiguous in RAM. 130func isAllArray(typ reflect.Type) bool { 131 switch typ.Kind() { 132 case reflect.String: 133 return false 134 case reflect.Slice: 135 return false 136 case reflect.Array: 137 return isAllArray(typ.Elem()) 138 default: 139 // We know the type is slices/arrays of slices/arrays of primitive types. 140 return true 141 } 142} 143 144// eface defines what an interface type actually is: a pointer to type 145// information about the encapsulated type and a pointer to the encapsulated 146// value. 147type eface struct { 148 rtype unsafe.Pointer 149 data unsafe.Pointer 150} 151 152// unpackEFace gives us an effient way to get us a pointer to the value carried 153// in an interface. If you wrap a pointer type in an interface then the pointer 154// is directly stored in the interface struct. If you wrap a value type in an 155// interface then the compiler copies the value into a newly allocated piece of 156// memory and stores a pointer to that memory in the interface. So we're 157// guaranteed to get a pointer. 
Go reflection doesn't expose the pointer to 158// value types straightforwardly as it doesn't want you to think you have a 159// reference to the original value. But we just want a pointer to make it 160// efficient to read the value, so cheating like this should be safe and 161// reasonable. 162func unpackEFace(obj interface{}) *eface { 163 return (*eface)(unsafe.Pointer(&obj)) 164} 165 166// ReadTensor constructs a Tensor with the provided type and shape from the 167// serialized tensor contents in r. 168// 169// See also WriteContentsTo. 170func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) { 171 if err := isTensorSerializable(dataType); err != nil { 172 return nil, err 173 } 174 175 var shapePtr *C.int64_t 176 if len(shape) > 0 { 177 for _, dim := range shape { 178 if dim < 0 { 179 return nil, fmt.Errorf("all shape dimentions should be non-negative: %v", shape) 180 } 181 } 182 shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) 183 } 184 185 nbytes := TypeOf(dataType, nil).Size() * uintptr(numElements(shape)) 186 t := &Tensor{ 187 c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), 188 shape: shape, 189 } 190 runtime.SetFinalizer(t, (*Tensor).finalize) 191 raw := tensorData(t.c) 192 if _, err := io.ReadFull(r, raw); err != nil { 193 return nil, err 194 } 195 return t, nil 196} 197 198// newTensorFromC takes ownership of c and returns the owning Tensor. 199func newTensorFromC(c *C.TF_Tensor) *Tensor { 200 var shape []int64 201 if ndims := int(C.TF_NumDims(c)); ndims > 0 { 202 shape = make([]int64, ndims) 203 } 204 for i := range shape { 205 shape[i] = int64(C.TF_Dim(c, C.int(i))) 206 } 207 t := &Tensor{c: c, shape: shape} 208 runtime.SetFinalizer(t, (*Tensor).finalize) 209 return t 210} 211 212func (t *Tensor) finalize() { C.TF_DeleteTensor(t.c) } 213 214// DataType returns the scalar datatype of the Tensor. 
func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }

// Shape returns the shape of the Tensor.
func (t *Tensor) Shape() []int64 { return t.shape }

// Reshape updates tensor's shape in place if this is possible or returns an error otherwise.
// The new shape must describe the same total number of elements as the old one.
func (t *Tensor) Reshape(newShape []int64) error {
	oldShapeSize := numElements(t.shape)
	newShapeSize := numElements(newShape)

	if oldShapeSize != newShapeSize {
		return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize)
	}

	if len(newShape) == 0 {
		return nil
	}

	var shapePtr *C.int64_t
	shapePtr = (*C.int64_t)(unsafe.Pointer(&newShape[0]))

	// Bitcast the tensor onto itself with the new dimensions; no data is
	// copied or converted.
	status := newStatus()
	C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c)

	if err := status.Err(); err != nil {
		return err
	}
	t.shape = newShape
	return nil
}

// Value converts the Tensor to a Go value. For now, not all Tensor types are
// supported, and this function may panic if it encounters an unsupported
// DataType.
//
// The type of the output depends on the Tensor type and dimensions.
// For example:
// Tensor(int64, 0): int64
// Tensor(float64, 3): [][][]float64
func (t *Tensor) Value() interface{} {
	raw := tensorData(t.c)
	shape := t.Shape()
	dt := t.DataType()
	return decodeTensor(raw, shape, dt).Interface()
}

// decodeTensor converts the raw tensor bytes into a reflect.Value with the
// nested Go slice structure implied by shape. For a scalar (len(shape) == 0)
// the single element is returned rather than a slice.
func decodeTensor(raw []byte, shape []int64, dt DataType) reflect.Value {
	// Create a 1-dimensional slice of the base large enough for the data and
	// copy the data in.
	n := int(numElements(shape))

	var (
		slice reflect.Value
		typ   reflect.Type
	)
	if dt == String {
		strs, err := decodeOneDimString(raw, n)
		if err != nil {
			panic(bug("unable to decode string with shape %v: %v", shape, err))
		}
		slice = reflect.ValueOf(strs)
		typ = slice.Type()
	} else {
		typ = typeForDataType(dt)
		l := n * int(typ.Size())
		typ = reflect.SliceOf(typ)
		slice = reflect.MakeSlice(typ, n, n)
		// View the freshly made slice's backing array as []byte so we can
		// bulk-copy the raw tensor bytes into it.
		baseBytes := *(*[]byte)(unsafe.Pointer(&sliceHeader{
			Data: unsafe.Pointer(slice.Pointer()),
			Len:  l,
			Cap:  l,
		}))
		copy(baseBytes, raw)
	}

	// Now we have the data in place in the base slice we can add the
	// dimensions. We want to walk backwards through the shape. If the shape is
	// length 1 or 0 then we're already done.
	if len(shape) == 0 {
		return slice.Index(0)
	}
	if len(shape) == 1 {
		return slice
	}
	// We have a special case if the tensor has no data. Our backing slice is
	// empty, but we still want to create slices following the shape. In this
	// case only the final part of the shape will be 0 and we want to recalculate
	// n at this point ignoring that 0.
	// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
	// want 6 zero length slices to group as follows.
	// {{} {}} {{} {}} {{} {}}
	if n == 0 {
		n = int(numElements(shape[:len(shape)-1]))
	}
	for i := len(shape) - 2; i >= 0; i-- {
		underlyingSize := typ.Elem().Size()
		typ = reflect.SliceOf(typ)
		subsliceLen := int(shape[i+1])
		if subsliceLen != 0 {
			n = n / subsliceLen
		}
		// Just using reflection it is difficult to avoid unnecessary
		// allocations while setting up the sub-slices as the Slice function on
		// a slice Value allocates. So we end up doing pointer arithmetic!
		// Pointer() on a slice gives us access to the data backing the slice.
		// We insert slice headers directly into this data.
		data := unsafe.Pointer(slice.Pointer())
		nextSlice := reflect.MakeSlice(typ, n, n)

		for j := 0; j < n; j++ {
			// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
			setSliceInSlice(nextSlice, j, sliceHeader{
				Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
				Len:  subsliceLen,
				Cap:  subsliceLen,
			})
		}

		slice = nextSlice
	}
	return slice
}

// setSliceInSlice sets slice[index] = content.
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
	const sliceSize = unsafe.Sizeof(sliceHeader{})
	// We must cast slice.Pointer to uintptr & back again to avoid GC issues.
	// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
	*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
}

// decodeOneDimString decodes a string tensor into a one-dimensional []string.
// raw is reinterpreted as an array of TF_TString headers; each string's bytes
// are copied into a fresh Go string via the C accessors.
func decodeOneDimString(raw []byte, nStrings int) ([]string, error) {
	strs := make([]string, nStrings)
	tstrs := (*(*[]C.TF_TString)(unsafe.Pointer(&raw)))[:nStrings]

	for i, tstr := range tstrs {
		dst := C.TF_TString_GetDataPointer(&tstr)
		dstLen := C.TF_TString_GetSize(&tstr)

		strs[i] = C.GoStringN(dst, C.int(dstLen))
	}

	return strs, nil
}

// WriteContentsTo writes the serialized contents of t to w.
//
// Returns the number of bytes written. See ReadTensor for
// reconstructing a Tensor from the serialized form.
//
// WARNING: WriteContentsTo is not comprehensive and will fail
// if t.DataType() is non-numeric (e.g., String). See
// https://github.com/tensorflow/tensorflow/issues/6003.
369func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) { 370 if err := isTensorSerializable(t.DataType()); err != nil { 371 return 0, err 372 } 373 return io.Copy(w, bytes.NewReader(tensorData(t.c))) 374} 375 376func tensorData(c *C.TF_Tensor) []byte { 377 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 378 cbytes := C.TF_TensorData(c) 379 if cbytes == nil { 380 return nil 381 } 382 length := int(C.TF_TensorByteSize(c)) 383 var slice []byte 384 if unsafe.Sizeof(unsafe.Pointer(nil)) == 8 { 385 slice = (*[1<<50 - 1]byte)(unsafe.Pointer(cbytes))[:length:length] 386 } else { 387 slice = (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length] 388 } 389 return slice 390} 391 392var types = []struct { 393 typ reflect.Type 394 dataType C.TF_DataType 395}{ 396 {reflect.TypeOf(float32(0)), C.TF_FLOAT}, 397 {reflect.TypeOf(float64(0)), C.TF_DOUBLE}, 398 {reflect.TypeOf(int32(0)), C.TF_INT32}, 399 {reflect.TypeOf(uint32(0)), C.TF_UINT32}, 400 {reflect.TypeOf(uint8(0)), C.TF_UINT8}, 401 {reflect.TypeOf(int16(0)), C.TF_INT16}, 402 {reflect.TypeOf(int8(0)), C.TF_INT8}, 403 {reflect.TypeOf(""), C.TF_STRING}, 404 {reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64}, 405 {reflect.TypeOf(int64(0)), C.TF_INT64}, 406 {reflect.TypeOf(uint64(0)), C.TF_UINT64}, 407 {reflect.TypeOf(false), C.TF_BOOL}, 408 {reflect.TypeOf(uint16(0)), C.TF_UINT16}, 409 {reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128}, 410 // TODO(apassos): support DT_RESOURCE representation in go. 411 // TODO(keveman): support DT_VARIANT representation in go. 412} 413 414// shapeAndDataTypeOf returns the data type and shape of the Tensor 415// corresponding to a Go type. 
416func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err error) { 417 typ := val.Type() 418 for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice { 419 shape = append(shape, int64(val.Len())) 420 // If slice elements are slices, verify that all of them have the same size. 421 // Go's type system makes that guarantee for arrays. 422 if val.Len() > 0 { 423 if val.Type().Elem().Kind() == reflect.Slice { 424 expected := val.Index(0).Len() 425 for i := 1; i < val.Len(); i++ { 426 if val.Index(i).Len() != expected { 427 return shape, dt, fmt.Errorf("mismatched slice lengths: %d and %d", val.Index(i).Len(), expected) 428 } 429 } 430 } 431 val = val.Index(0) 432 } 433 typ = typ.Elem() 434 } 435 for _, t := range types { 436 if typ.Kind() == t.typ.Kind() { 437 return shape, DataType(t.dataType), nil 438 } 439 } 440 return shape, dt, fmt.Errorf("unsupported type %v", typ) 441} 442 443func typeForDataType(dt DataType) reflect.Type { 444 for _, t := range types { 445 if dt == DataType(t.dataType) { 446 return t.typ 447 } 448 } 449 panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt)) 450} 451 452// TypeOf converts from a DataType and Shape to the equivalent Go type. 453func TypeOf(dt DataType, shape []int64) reflect.Type { 454 ret := typeForDataType(dt) 455 for range shape { 456 ret = reflect.SliceOf(ret) 457 } 458 return ret 459} 460 461func numElements(shape []int64) int64 { 462 n := int64(1) 463 for _, d := range shape { 464 n *= d 465 } 466 return n 467} 468 469// sizeVarUint determines how many bytes it would take to encode the int v as 470// an unsigned varint 471func sizeVarUint(v uint64) int { 472 if v < 0x80 { 473 return 1 474 } 475 bits := bits.Len64(v) 476 return (bits + 6) / 7 477} 478 479// encodeTensorWithSlices writes v to the specified buffer using the format specified in 480// c_api.h. Use stringEncoder for String tensors. 
func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) (int, error) {
	// If current dimension is a slice, verify that it has the expected size
	// Go's type system makes that guarantee for arrays.
	if v.Kind() == reflect.Slice {
		expected := int(shape[0])
		if v.Len() != expected {
			return 0, fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
		}
	} else if v.Kind() == reflect.String {
		// Strings are converted to TF_TString headers and written directly
		// into the tensor buffer.
		// NOTE(review): presumably any heap storage the TF_TString acquires
		// for long strings is released when the owning tensor is deleted —
		// confirm against the TF_TString/TF_DeleteTensor contract.
		s := v.Interface().(string)
		var tstr C.TF_TString
		C.toNewTString(s, &tstr)
		ptr := unsafe.Pointer(&tstr)
		return copyPtr(w, ptr, C.sizeof_TF_TString)
	} else if v.Kind() != reflect.Array {
		return 0, fmt.Errorf("unsupported type %v", v.Type())
	}

	// Once we have just a single dimension we can just copy the data
	if len(shape) == 1 && v.Len() > 0 && v.Index(0).Kind() != reflect.String {
		elt := v.Index(0)
		if !elt.CanAddr() {
			panic("cannot take address")
		}
		ptr := unsafe.Pointer(elt.Addr().Pointer())
		return copyPtr(w, ptr, v.Len()*int(elt.Type().Size()))
	}

	// Otherwise recurse into each element with the remaining shape,
	// accumulating the number of bytes written.
	n := 0
	subShape := shape[1:]
	for i := 0; i < v.Len(); i++ {
		j, err := encodeTensorWithSlices(w, v.Index(i), subShape)
		if err != nil {
			return n + j, err
		}
		n += j
	}

	return n, nil
}

// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
// this is not inspected by the garbage collector
type sliceHeader struct {
	Data unsafe.Pointer
	Len  int
	Cap  int
}

// copyPtr copies the backing data for a slice or array directly into w. Note
// we don't need to worry about byte ordering because we want the natural byte
// order for the machine we're running on.
func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) (int, error) {
	// Convert our slice header into a []byte so we can call w.Write
	b := *(*[]byte)(unsafe.Pointer(&sliceHeader{
		Data: ptr,
		Len:  l,
		Cap:  l,
	}))
	return w.Write(b)
}

// bug returns an error asking the user to file an issue; used for conditions
// that indicate a defect in this package rather than in the caller's input.
func bug(format string, args ...interface{}) error {
	return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
}

// isTensorSerializable reports whether tensors of dataType can be serialized
// with WriteContentsTo / deserialized with ReadTensor.
func isTensorSerializable(dataType DataType) error {
	// For numeric types, the serialized Tensor matches the in-memory
	// representation. See the implementation of Tensor::AsProtoContent in
	// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
	//
	// The more appropriate way to be in sync with Tensor::AsProtoContent
	// would be to have the TensorFlow C library export functions for
	// serialization and deserialization of Tensors. Till then capitalize
	// on knowledge of the implementation for numeric types.
	switch dataType {
	case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half:
		return nil
	default:
		return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
	}
}