1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import java.lang.reflect.Array; 19 import java.nio.Buffer; 20 import java.nio.BufferOverflowException; 21 import java.nio.ByteBuffer; 22 import java.nio.ByteOrder; 23 import java.nio.DoubleBuffer; 24 import java.nio.FloatBuffer; 25 import java.nio.IntBuffer; 26 import java.nio.LongBuffer; 27 import java.util.Arrays; 28 import java.util.HashMap; 29 30 /** 31 * A statically typed multi-dimensional array whose elements are of a type described by T. 32 * 33 * <p>Instances of a Tensor are <b>not</b> thread-safe. 34 * 35 * <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by 36 * invoking the {@link #close()} method when the object is no longer needed. For example, using a 37 * try-with-resources block: 38 * 39 * <pre>{@code 40 * try (Tensor t = Tensor.create(...)) { 41 * doSomethingWith(t); 42 * } 43 * }</pre> 44 */ 45 public final class Tensor<T> implements AutoCloseable { 46 47 /** 48 * Creates a Tensor from a Java object. 49 * 50 * <p>A {@code Tensor} is a multi-dimensional array of elements of a limited set of types. Not all 51 * Java objects can be converted to a {@code Tensor}. In particular, the argument {@code obj} must 52 * be either a primitive (float, double, int, long, boolean, byte) or a multi-dimensional array of 53 * one of those primitives. The argument {@code type} specifies how to interpret the first 54 * argument as a TensorFlow type. For example: 55 * 56 * <pre>{@code 57 * // Valid: A 64-bit integer scalar. 58 * Tensor<Long> s = Tensor.create(42L, Long.class); 59 * 60 * // Valid: A 3x2 matrix of floats. 61 * float[][] matrix = new float[3][2]; 62 * Tensor<Float> m = Tensor.create(matrix, Float.class); 63 * 64 * // Invalid: Will throw an IllegalArgumentException as an arbitrary Object 65 * // does not fit into the TensorFlow type system. 66 * Tensor<?> o = Tensor.create(new Object()) 67 * 68 * // Invalid: Will throw an IllegalArgumentException since there are 69 * // a differing number of elements in each row of this 2-D array. 70 * int[][] twoD = new int[2][]; 71 * twoD[0] = new int[1]; 72 * twoD[1] = new int[2]; 73 * Tensor<Integer> x = Tensor.create(twoD, Integer.class); 74 * }</pre> 75 * 76 * {@link String}-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so can 77 * be initialized from arrays of {@code byte[]} elements. For example: 78 * 79 * <pre>{@code 80 * // Valid: A String tensor. 81 * Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class); 82 * 83 * // Java Strings will need to be encoded into a byte-sequence. 84 * String mystring = "foo"; 85 * Tensor<String> s = Tensor.create(mystring.getBytes("UTF-8"), String.class); 86 * 87 * // Valid: Matrix of String tensors. 88 * // Each element might have a different length. 89 * byte[][][] matrix = new byte[2][2][]; 90 * matrix[0][0] = "this".getBytes("UTF-8"); 91 * matrix[0][1] = "is".getBytes("UTF-8"); 92 * matrix[1][0] = "a".getBytes("UTF-8"); 93 * matrix[1][1] = "matrix".getBytes("UTF-8"); 94 * Tensor<String> m = Tensor.create(matrix, String.class); 95 * }</pre> 96 * 97 * @param obj The object to convert to a {@code Tensor<T>}. Note that whether it is compatible 98 * with the type T is not checked by the type system. For type-safe creation of tensors, use 99 * {@link Tensors}. 100 * @param type The class object representing the type T. 101 * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type 102 * system. 103 */ 104 @SuppressWarnings("unchecked") create(Object obj, Class<T> type)105 public static <T> Tensor<T> create(Object obj, Class<T> type) { 106 DataType dtype = DataType.fromClass(type); 107 if (!objectCompatWithType(obj, dtype)) { 108 throw new IllegalArgumentException( 109 "DataType of object does not match T (expected " 110 + dtype 111 + ", got " 112 + dataTypeOf(obj) 113 + ")"); 114 } 115 return (Tensor<T>) create(obj, dtype); 116 } 117 118 /** 119 * Creates a tensor from an object whose class is inspected to figure out what the underlying data 120 * type should be. 121 * 122 * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type 123 * system. 124 */ create(Object obj)125 public static Tensor<?> create(Object obj) { 126 return create(obj, dataTypeOf(obj)); 127 } 128 129 /** 130 * Create a Tensor of data type {@code dtype} from a Java object. Requires the parameter {@code T} 131 * to match {@code type}, but this condition is not checked. 132 * 133 * @param obj the object supplying the tensor data. 134 * @param dtype the data type of the tensor to create. It must be compatible with the run-time 135 * type of the object. 136 * @return the new tensor 137 */ create(Object obj, DataType dtype)138 private static Tensor<?> create(Object obj, DataType dtype) { 139 @SuppressWarnings("rawtypes") 140 Tensor<?> t = new Tensor(dtype); 141 t.shapeCopy = new long[numDimensions(obj, dtype)]; 142 fillShape(obj, 0, t.shapeCopy); 143 long nativeHandle; 144 if (t.dtype != DataType.STRING) { 145 int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy); 146 nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); 147 setValue(nativeHandle, obj); 148 } else if (t.shapeCopy.length != 0) { 149 nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); 150 } else { 151 nativeHandle = allocateScalarBytes((byte[]) obj); 152 } 153 t.nativeRef = new NativeReference(nativeHandle); 154 return t; 155 } 156 157 /** 158 * Create a {@link Integer} Tensor with data from the given buffer. 159 * 160 * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its 161 * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a 162 * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this 163 * method. 164 * 165 * @param shape the tensor shape. 166 * @param data a buffer containing the tensor data. 167 * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer 168 */ create(long[] shape, IntBuffer data)169 public static Tensor<Integer> create(long[] shape, IntBuffer data) { 170 Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape, data.remaining()); 171 t.buffer().asIntBuffer().put(data); 172 return t; 173 } 174 175 /** 176 * Create a {@link Float} Tensor with data from the given buffer. 177 * 178 * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its 179 * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a 180 * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this 181 * method. 182 * 183 * @param shape the tensor shape. 184 * @param data a buffer containing the tensor data. 185 * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer 186 */ create(long[] shape, FloatBuffer data)187 public static Tensor<Float> create(long[] shape, FloatBuffer data) { 188 Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape, data.remaining()); 189 t.buffer().asFloatBuffer().put(data); 190 return t; 191 } 192 193 /** 194 * Create a {@link Double} Tensor with data from the given buffer. 195 * 196 * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its 197 * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a 198 * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this 199 * method. 200 * 201 * @param shape the tensor shape. 202 * @param data a buffer containing the tensor data. 203 * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer 204 */ create(long[] shape, DoubleBuffer data)205 public static Tensor<Double> create(long[] shape, DoubleBuffer data) { 206 Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining()); 207 t.buffer().asDoubleBuffer().put(data); 208 return t; 209 } 210 211 /** 212 * Create an {@link Long} Tensor with data from the given buffer. 213 * 214 * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its 215 * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a 216 * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this 217 * method. 218 * 219 * @param shape the tensor shape. 220 * @param data a buffer containing the tensor data. 221 * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer 222 */ create(long[] shape, LongBuffer data)223 public static Tensor<Long> create(long[] shape, LongBuffer data) { 224 Tensor<Long> t = allocateForBuffer(DataType.INT64, shape, data.remaining()); 225 t.buffer().asLongBuffer().put(data); 226 return t; 227 } 228 229 /** 230 * Create a Tensor of any type with data from the given buffer. 231 * 232 * <p>Creates a Tensor with the provided shape of any type where the tensor's data has been 233 * encoded into {@code data} as per the specification of the TensorFlow <a 234 * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C 235 * API</a>. 236 * 237 * @param <T> the tensor element type 238 * @param type the tensor element type, represented as a class object. 239 * @param shape the tensor shape. 240 * @param data a buffer containing the tensor data. 241 * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the 242 * buffer 243 */ create(Class<T> type, long[] shape, ByteBuffer data)244 public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) { 245 @SuppressWarnings("unchecked") 246 Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data); 247 return ret; 248 } 249 create(DataType dtype, long[] shape, ByteBuffer data)250 private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) { 251 int nremaining; 252 if (dtype != DataType.STRING) { 253 int elemBytes = elemByteSize(dtype); 254 if (data.remaining() % elemBytes != 0) { 255 throw new IllegalArgumentException( 256 String.format( 257 "ByteBuffer with %d bytes is not compatible with a %s Tensor (%d bytes/element)", 258 data.remaining(), dtype.toString(), elemBytes)); 259 } 260 nremaining = data.remaining() / elemBytes; 261 } else { 262 nremaining = data.remaining(); 263 } 264 Tensor<?> t = allocateForBuffer(dtype, shape, nremaining); 265 t.buffer().put(data); 266 return t; 267 } 268 269 /** 270 * Returns this Tensor object with the type {@code Tensor<U>}. This method is useful when given a 271 * value of type {@code Tensor<?>}. 272 * 273 * @param type any (non-null) array of the correct type. 274 * @throws IllegalArgumentException if the actual data type of this object does not match the type 275 * {@code U}. 276 */ 277 @SuppressWarnings("unchecked") expect(Class<U> type)278 public <U> Tensor<U> expect(Class<U> type) { 279 DataType dt = DataType.fromClass(type); 280 if (!dt.equals(dtype)) { 281 throw new IllegalArgumentException( 282 "Cannot cast from tensor of " + dtype + " to tensor of " + dt); 283 } 284 return ((Tensor<U>) this); 285 } 286 287 // Helper function to allocate a Tensor for the create() methods that create a Tensor from 288 // a java.nio.Buffer. 289 // Requires: dataType matches T allocateForBuffer(DataType dataType, long[] shape, int nBuffered)290 private static <T> Tensor<T> allocateForBuffer(DataType dataType, long[] shape, int nBuffered) { 291 final int nflattened = numElements(shape); 292 int nbytes = 0; 293 if (dataType != DataType.STRING) { 294 if (nBuffered != nflattened) { 295 throw incompatibleBuffer(nBuffered, shape); 296 } 297 nbytes = nflattened * elemByteSize(dataType); 298 } else { 299 // DT_STRING tensor encoded in a ByteBuffer. 300 nbytes = nBuffered; 301 } 302 Tensor<T> t = new Tensor<T>(dataType); 303 t.shapeCopy = Arrays.copyOf(shape, shape.length); 304 long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes); 305 t.nativeRef = new NativeReference(nativeHandle); 306 return t; 307 } 308 309 /** 310 * Release resources associated with the Tensor. 311 * 312 * <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager 313 * operation or memory will be leaked. 314 * 315 * <p>The Tensor object is no longer usable after {@code close} returns. 316 */ 317 @Override close()318 public void close() { 319 nativeRef.release(); 320 } 321 322 /** Returns the {@link DataType} of elements stored in the Tensor. */ dataType()323 public DataType dataType() { 324 return dtype; 325 } 326 327 /** 328 * Returns the number of dimensions (sometimes referred to as <a 329 * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor. 330 * 331 * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc. 332 */ numDimensions()333 public int numDimensions() { 334 return shapeCopy.length; 335 } 336 337 /** Returns the size, in bytes, of the tensor data. */ numBytes()338 public int numBytes() { 339 return buffer().remaining(); 340 } 341 342 /** Returns the number of elements in a flattened (1-D) view of the tensor. */ numElements()343 public int numElements() { 344 return numElements(shapeCopy); 345 } 346 347 /** 348 * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of 349 * the Tensor, i.e., the sizes of each dimension. 350 * 351 * @return an array where the i-th element is the size of the i-th dimension of the tensor. 352 */ shape()353 public long[] shape() { 354 return shapeCopy; 355 } 356 357 /** 358 * Returns the value in a scalar {@link Float} tensor. 359 * 360 * @throws IllegalArgumentException if the Tensor does not represent a float scalar. 361 */ floatValue()362 public float floatValue() { 363 return scalarFloat(getNativeHandle()); 364 } 365 366 /** 367 * Returns the value in a scalar {@link Double} tensor. 368 * 369 * @throws IllegalArgumentException if the Tensor does not represent a double scalar. 370 */ doubleValue()371 public double doubleValue() { 372 return scalarDouble(getNativeHandle()); 373 } 374 375 /** 376 * Returns the value in a scalar {@link Integer} tensor. 377 * 378 * @throws IllegalArgumentException if the Tensor does not represent a int scalar. 379 */ intValue()380 public int intValue() { 381 return scalarInt(getNativeHandle()); 382 } 383 384 /** 385 * Returns the value in a scalar {@link Long} tensor. 386 * 387 * @throws IllegalArgumentException if the Tensor does not represent a long scalar. 388 */ longValue()389 public long longValue() { 390 return scalarLong(getNativeHandle()); 391 } 392 393 /** 394 * Returns the value in a scalar {@link Boolean} tensor. 395 * 396 * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. 397 */ booleanValue()398 public boolean booleanValue() { 399 return scalarBoolean(getNativeHandle()); 400 } 401 402 /** 403 * Returns the value in a scalar {@link String} tensor. 404 * 405 * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. 406 */ bytesValue()407 public byte[] bytesValue() { 408 return scalarBytes(getNativeHandle()); 409 } 410 411 /** 412 * Copies the contents of the tensor to {@code dst} and returns {@code dst}. 413 * 414 * <p>For non-scalar tensors, this method copies the contents of the underlying tensor to a Java 415 * array. For scalar tensors, use one of {@link #bytesValue()}, {@link #floatValue()}, {@link 416 * #doubleValue()}, {@link #intValue()}, {@link #longValue()} or {@link #booleanValue()} instead. 417 * The type and shape of {@code dst} must be compatible with the tensor. For example: 418 * 419 * <pre>{@code 420 * int matrix[2][2] = {{1,2},{3,4}}; 421 * try(Tensor t = Tensor.create(matrix)) { 422 * // Succeeds and prints "3" 423 * int[][] copy = new int[2][2]; 424 * System.out.println(t.copyTo(copy)[1][0]); 425 * 426 * // Throws IllegalArgumentException since the shape of dst does not match the shape of t. 427 * int[][] dst = new int[4][1]; 428 * t.copyTo(dst); 429 * } 430 * }</pre> 431 * 432 * @throws IllegalArgumentException if the tensor is a scalar or if {@code dst} is not compatible 433 * with the tensor (for example, mismatched data types or shapes). 434 */ copyTo(U dst)435 public <U> U copyTo(U dst) { 436 throwExceptionIfTypeIsIncompatible(dst); 437 readNDArray(getNativeHandle(), dst); 438 return dst; 439 } 440 441 /** 442 * Write the data of a {@link Integer} tensor into the given buffer. 443 * 444 * <p>Copies {@code numElements()} elements to the buffer. 445 * 446 * @param dst the destination buffer 447 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data 448 * in this tensor 449 * @throws IllegalArgumentException If the tensor data type is not {@link Integer} 450 */ writeTo(IntBuffer dst)451 public void writeTo(IntBuffer dst) { 452 if (dtype != DataType.INT32) { 453 throw incompatibleBuffer(dst, dtype); 454 } 455 ByteBuffer src = buffer(); 456 dst.put(src.asIntBuffer()); 457 } 458 459 /** 460 * Write the data of a {@link Float} tensor into the given buffer. 461 * 462 * <p>Copies {@code numElements()} elements to the buffer. 463 * 464 * @param dst the destination buffer 465 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data 466 * in this tensor 467 * @throws IllegalArgumentException If the tensor datatype is not {@link Float} 468 */ writeTo(FloatBuffer dst)469 public void writeTo(FloatBuffer dst) { 470 if (dtype != DataType.FLOAT) { 471 throw incompatibleBuffer(dst, dtype); 472 } 473 ByteBuffer src = buffer(); 474 dst.put(src.asFloatBuffer()); 475 } 476 477 /** 478 * Write the data of a {@link Double} tensor into the given buffer. 479 * 480 * <p>Copies {@code numElements()} elements to the buffer. 481 * 482 * @param dst the destination buffer 483 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data 484 * in this tensor 485 * @throws IllegalArgumentException If the tensor datatype is not {@link Double} 486 */ writeTo(DoubleBuffer dst)487 public void writeTo(DoubleBuffer dst) { 488 if (dtype != DataType.DOUBLE) { 489 throw incompatibleBuffer(dst, dtype); 490 } 491 ByteBuffer src = buffer(); 492 dst.put(src.asDoubleBuffer()); 493 } 494 495 /** 496 * Write the data of a {@link Long} tensor into the given buffer. 497 * 498 * <p>Copies {@code numElements()} elements to the buffer. 499 * 500 * @param dst the destination buffer 501 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data 502 * in this tensor 503 * @throws IllegalArgumentException If the tensor datatype is not {@link Long} 504 */ writeTo(LongBuffer dst)505 public void writeTo(LongBuffer dst) { 506 if (dtype != DataType.INT64) { 507 throw incompatibleBuffer(dst, dtype); 508 } 509 ByteBuffer src = buffer(); 510 dst.put(src.asLongBuffer()); 511 } 512 513 /** 514 * Write the tensor data into the given buffer. 515 * 516 * <p>Copies {@code numBytes()} bytes to the buffer in native byte order for primitive types. 517 * 518 * @param dst the destination buffer 519 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data 520 * in this tensor 521 */ writeTo(ByteBuffer dst)522 public void writeTo(ByteBuffer dst) { 523 ByteBuffer src = buffer(); 524 dst.put(src); 525 } 526 527 /** Returns a string describing the type and shape of the Tensor. */ 528 @Override toString()529 public String toString() { 530 return String.format("%s tensor with shape %s", dtype.toString(), Arrays.toString(shape())); 531 } 532 533 /** 534 * Create a Tensor object from a handle to the C TF_Tensor object. 535 * 536 * <p>Takes ownership of the handle. 537 */ fromHandle(long handle)538 static Tensor<?> fromHandle(long handle) { 539 @SuppressWarnings("rawtypes") 540 Tensor<?> t = new Tensor(DataType.fromC(dtype(handle))); 541 t.shapeCopy = shape(handle); 542 t.nativeRef = new NativeReference(handle); 543 return t; 544 } 545 546 /** 547 * Create an eager Tensor object from a handle to the C TF_Tensor object. 548 * 549 * <p>Takes ownership of the handle. 550 */ fromHandle(long handle, EagerSession session)551 static Tensor<?> fromHandle(long handle, EagerSession session) { 552 Tensor<?> t = fromHandle(handle); 553 t.nativeRef.eager(session, t); 554 return t; 555 } 556 getNativeHandle()557 long getNativeHandle() { 558 return nativeRef.tensorHandle; 559 } 560 561 private NativeReference nativeRef = null; 562 private final DataType dtype; 563 private long[] shapeCopy = null; 564 Tensor(DataType t)565 private Tensor(DataType t) { 566 dtype = t; 567 } 568 buffer()569 private ByteBuffer buffer() { 570 return buffer(getNativeHandle()).order(ByteOrder.nativeOrder()); 571 } 572 incompatibleBuffer(Buffer buf, DataType dataType)573 private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) { 574 return new IllegalArgumentException( 575 String.format("cannot use %s with Tensor of type %s", buf.getClass().getName(), dataType)); 576 } 577 incompatibleBuffer(int numElements, long[] shape)578 private static IllegalArgumentException incompatibleBuffer(int numElements, long[] shape) { 579 return new IllegalArgumentException( 580 String.format( 581 "buffer with %d elements is not compatible with a Tensor with shape %s", 582 numElements, Arrays.toString(shape))); 583 } 584 numElements(long[] shape)585 private static int numElements(long[] shape) { 586 // assumes a fully-known shape 587 int n = 1; 588 for (int i = 0; i < shape.length; i++) { 589 n *= (int) shape[i]; 590 } 591 return n; 592 } 593 elemByteSize(DataType dataType)594 private static int elemByteSize(DataType dataType) { 595 int size = dataType.byteSize(); 596 if (size < 0) { 597 throw new IllegalArgumentException("STRING tensors do not have a fixed element size"); 598 } 599 return size; 600 } 601 throwExceptionIfNotByteOfByteArrays(Object array)602 private static void throwExceptionIfNotByteOfByteArrays(Object array) { 603 if (!array.getClass().getName().equals("[[B")) { 604 throw new IllegalArgumentException( 605 "object cannot be converted to a Tensor as it includes an array with null elements"); 606 } 607 } 608 609 /** 610 * Reference to the underlying native tensor 611 * 612 * <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get 613 * automatically released after executing the last line of the `try` block they were declared in. 614 * 615 * <p>They can also be attached to an eager session, where in this case their lifetime ends either 616 * when this session is closed or when the Tensor instance is no longer referenced and have been 617 * garbage-collected. 618 * 619 * <p>This helper class wraps the tensor native handle and support both situations; If an eager 620 * reference to the tensor exists, it will take care of releasing the tensor at the end of its 621 * life. If the tensor is being explicitly closed before this happens, it will take cake of 622 * clearing its association with any eager session before cleaning up the resources. 623 */ 624 private static class NativeReference { 625 626 /** Attaches this reference to an eager session */ 627 private class EagerReference extends EagerSession.NativeReference { 628 EagerReference(EagerSession session, Tensor<?> tensor)629 EagerReference(EagerSession session, Tensor<?> tensor) { 630 super(session, tensor); 631 } 632 633 @Override delete()634 void delete() { 635 // Mark this eager reference as cleared since it has been deleted by the session 636 NativeReference.this.eagerRef = null; 637 NativeReference.this.release(); 638 } 639 } 640 NativeReference(long tensorHandle)641 NativeReference(long tensorHandle) { 642 this.tensorHandle = tensorHandle; 643 } 644 eager(EagerSession session, Tensor<?> tensor)645 void eager(EagerSession session, Tensor<?> tensor) { 646 if (eagerRef != null) { 647 throw new IllegalStateException("The tensor is already attached to an eager session"); 648 } 649 eagerRef = new EagerReference(session, tensor); 650 } 651 release()652 synchronized void release() { 653 if (tensorHandle != 0L) { 654 // Clear any remaining eager reference to this tensor 655 if (eagerRef != null) { 656 eagerRef.clear(); 657 eagerRef = null; 658 } 659 Tensor.delete(tensorHandle); 660 tensorHandle = 0L; 661 } 662 } 663 664 private long tensorHandle; 665 private EagerReference eagerRef; 666 } 667 668 private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>(); 669 670 static { classDataTypes.put(int.class, DataType.INT32)671 classDataTypes.put(int.class, DataType.INT32); classDataTypes.put(Integer.class, DataType.INT32)672 classDataTypes.put(Integer.class, DataType.INT32); classDataTypes.put(long.class, DataType.INT64)673 classDataTypes.put(long.class, DataType.INT64); classDataTypes.put(Long.class, DataType.INT64)674 classDataTypes.put(Long.class, DataType.INT64); classDataTypes.put(float.class, DataType.FLOAT)675 classDataTypes.put(float.class, DataType.FLOAT); classDataTypes.put(Float.class, DataType.FLOAT)676 classDataTypes.put(Float.class, DataType.FLOAT); classDataTypes.put(double.class, DataType.DOUBLE)677 classDataTypes.put(double.class, DataType.DOUBLE); classDataTypes.put(Double.class, DataType.DOUBLE)678 classDataTypes.put(Double.class, DataType.DOUBLE); classDataTypes.put(byte.class, DataType.STRING)679 classDataTypes.put(byte.class, DataType.STRING); classDataTypes.put(Byte.class, DataType.STRING)680 classDataTypes.put(Byte.class, DataType.STRING); classDataTypes.put(boolean.class, DataType.BOOL)681 classDataTypes.put(boolean.class, DataType.BOOL); classDataTypes.put(Boolean.class, DataType.BOOL)682 classDataTypes.put(Boolean.class, DataType.BOOL); 683 } 684 685 /** The class for the data type to which Java object o corresponds. */ baseObjType(Object o)686 private static Class<?> baseObjType(Object o) { 687 Class<?> c = o.getClass(); 688 while (c.isArray()) { 689 c = c.getComponentType(); 690 } 691 return c; 692 } 693 694 /** 695 * The default TensorFlow data type to which Java object o corresponds. Some Java objects 696 * represent more than one TensorFlow data type; for example, 'byte' can represent both {@code 697 * uint8} and {@code string}, with the latter being the default interpretation. 698 */ dataTypeOf(Object o)699 private static DataType dataTypeOf(Object o) { 700 Class<?> c = baseObjType(o); 701 return dataTypeFromClass(c); 702 } 703 dataTypeFromClass(Class<?> c)704 private static DataType dataTypeFromClass(Class<?> c) { 705 DataType ret = classDataTypes.get(c); 706 if (ret != null) { 707 return ret; 708 } 709 throw new IllegalArgumentException("cannot create Tensors of type " + c.getName()); 710 } 711 712 /** 713 * Return the number of dimensions of the tensor that object {@code o} represents as a tensor 714 * whose datatype is {@code dtype}. Normally this is the same as the number of dimensions of o 715 * itself, but is one smaller for tensors of strings. 716 * 717 * @param o The object to inspect. It must be a valid representation of the given data type. 718 * @param dtype The expected data type of the tensor. 719 */ numDimensions(Object o, DataType dtype)720 private static int numDimensions(Object o, DataType dtype) { 721 int ret = numArrayDimensions(o); 722 if (dtype == DataType.STRING && ret > 0) { 723 return ret - 1; 724 } 725 return ret; 726 } 727 728 /** Returns the number of dimensions of the array object o. Returns 0 if o is not an array. */ numArrayDimensions(Object o)729 private static int numArrayDimensions(Object o) { 730 Class<?> c = o.getClass(); 731 int i = 0; 732 while (c.isArray()) { 733 c = c.getComponentType(); 734 i++; 735 } 736 return i; 737 } 738 739 /** 740 * Fills in the remaining entries in the shape array starting from position {@code dim} with the 741 * dimension sizes of the multidimensional array o. Checks that all arrays reachable from o have 742 * sizes consistent with the filled-in shape, throwing IllegalArgumentException otherwise. 743 */ fillShape(Object o, int dim, long[] shape)744 private static void fillShape(Object o, int dim, long[] shape) { 745 if (shape == null || dim == shape.length) { 746 return; 747 } 748 final int len = Array.getLength(o); 749 if (len == 0) { 750 throw new IllegalArgumentException("cannot create Tensors with a 0 dimension"); 751 } 752 if (shape[dim] == 0) { 753 shape[dim] = len; 754 } else if (shape[dim] != len) { 755 throw new IllegalArgumentException( 756 String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); 757 } 758 for (int i = 0; i < len; ++i) { 759 fillShape(Array.get(o, i), dim + 1, shape); 760 } 761 } 762 763 /** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */ objectCompatWithType(Object obj, DataType dtype)764 private static boolean objectCompatWithType(Object obj, DataType dtype) { 765 Class<?> c = baseObjType(obj); 766 DataType dto = dataTypeFromClass(c); 767 int nd = numDimensions(obj, dto); 768 if (!c.isPrimitive() && c != String.class && nd != 0) { 769 throw new IllegalArgumentException( 770 "cannot create non-scalar Tensors from arrays of boxed values"); 771 } 772 if (dto.equals(dtype)) { 773 return true; 774 } 775 if (dto == DataType.STRING && dtype == DataType.UINT8) { 776 return true; 777 } 778 return false; 779 } 780 throwExceptionIfTypeIsIncompatible(Object o)781 private void throwExceptionIfTypeIsIncompatible(Object o) { 782 final int rank = numDimensions(); 783 final int oRank = numDimensions(o, dtype); 784 if (oRank != rank) { 785 throw new IllegalArgumentException( 786 String.format( 787 "cannot copy Tensor with %d dimensions into an object with %d", rank, oRank)); 788 } 789 if (!objectCompatWithType(o, dtype)) { 790 throw new IllegalArgumentException( 791 String.format( 792 "cannot copy Tensor with DataType %s into an object of type %s", 793 dtype.toString(), o.getClass().getName())); 794 } 795 long[] oShape = new long[rank]; 796 fillShape(o, 0, oShape); 797 for (int i = 0; i < oShape.length; ++i) { 798 if (oShape[i] != shape()[i]) { 799 throw new IllegalArgumentException( 800 String.format( 801 "cannot copy Tensor with shape %s into object with shape %s", 802 Arrays.toString(shape()), Arrays.toString(oShape))); 803 } 804 } 805 } 806 allocate(int dtype, long[] shape, long byteSize)807 private static native long allocate(int dtype, long[] shape, long byteSize); 808 allocateScalarBytes(byte[] value)809 private static native long allocateScalarBytes(byte[] value); 810 allocateNonScalarBytes(long[] shape, Object[] value)811 private static native long allocateNonScalarBytes(long[] shape, Object[] value); 812 delete(long handle)813 private static native void delete(long handle); 814 buffer(long handle)815 private static native ByteBuffer buffer(long handle); 816 dtype(long handle)817 private static native int dtype(long handle); 818 shape(long handle)819 private static native long[] shape(long handle); 820 setValue(long handle, Object value)821 private static native void setValue(long handle, Object value); 822 scalarFloat(long handle)823 private static native float scalarFloat(long handle); 824 scalarDouble(long handle)825 private static native double scalarDouble(long handle); 826 scalarInt(long handle)827 private static native int scalarInt(long handle); 828 scalarLong(long handle)829 private static native long scalarLong(long handle); 830 scalarBoolean(long handle)831 private static native boolean scalarBoolean(long handle); 832 scalarBytes(long handle)833 private static native byte[] scalarBytes(long handle); 834 readNDArray(long handle, Object value)835 private static native void readNDArray(long handle, Object value); 836 837 static { TensorFlow.init()838 TensorFlow.init(); 839 } 840 } 841