xref: /aosp_15_r20/external/tensorflow/tensorflow/java/src/main/java/org/tensorflow/Tensor.java (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import java.lang.reflect.Array;
19 import java.nio.Buffer;
20 import java.nio.BufferOverflowException;
21 import java.nio.ByteBuffer;
22 import java.nio.ByteOrder;
23 import java.nio.DoubleBuffer;
24 import java.nio.FloatBuffer;
25 import java.nio.IntBuffer;
26 import java.nio.LongBuffer;
27 import java.util.Arrays;
28 import java.util.HashMap;
29 
30 /**
31  * A statically typed multi-dimensional array whose elements are of a type described by T.
32  *
33  * <p>Instances of a Tensor are <b>not</b> thread-safe.
34  *
35  * <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by
36  * invoking the {@link #close()} method when the object is no longer needed. For example, using a
37  * try-with-resources block:
38  *
39  * <pre>{@code
40  * try (Tensor t = Tensor.create(...)) {
41  *   doSomethingWith(t);
42  * }
43  * }</pre>
44  */
45 public final class Tensor<T> implements AutoCloseable {
46 
47   /**
48    * Creates a Tensor from a Java object.
49    *
50    * <p>A {@code Tensor} is a multi-dimensional array of elements of a limited set of types. Not all
51    * Java objects can be converted to a {@code Tensor}. In particular, the argument {@code obj} must
52    * be either a primitive (float, double, int, long, boolean, byte) or a multi-dimensional array of
53    * one of those primitives. The argument {@code type} specifies how to interpret the first
54    * argument as a TensorFlow type. For example:
55    *
56    * <pre>{@code
57    * // Valid: A 64-bit integer scalar.
58    * Tensor<Long> s = Tensor.create(42L, Long.class);
59    *
60    * // Valid: A 3x2 matrix of floats.
61    * float[][] matrix = new float[3][2];
62    * Tensor<Float> m = Tensor.create(matrix, Float.class);
63    *
64    * // Invalid: Will throw an IllegalArgumentException as an arbitrary Object
65    * // does not fit into the TensorFlow type system.
66    * Tensor<?> o = Tensor.create(new Object())
67    *
68    * // Invalid: Will throw an IllegalArgumentException since there are
69    * // a differing number of elements in each row of this 2-D array.
70    * int[][] twoD = new int[2][];
71    * twoD[0] = new int[1];
72    * twoD[1] = new int[2];
73    * Tensor<Integer> x = Tensor.create(twoD, Integer.class);
74    * }</pre>
75    *
76    * {@link String}-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so can
77    * be initialized from arrays of {@code byte[]} elements. For example:
78    *
79    * <pre>{@code
80    * // Valid: A String tensor.
81    * Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);
82    *
83    * // Java Strings will need to be encoded into a byte-sequence.
84    * String mystring = "foo";
85    * Tensor<String> s = Tensor.create(mystring.getBytes("UTF-8"), String.class);
86    *
87    * // Valid: Matrix of String tensors.
88    * // Each element might have a different length.
89    * byte[][][] matrix = new byte[2][2][];
90    * matrix[0][0] = "this".getBytes("UTF-8");
91    * matrix[0][1] = "is".getBytes("UTF-8");
92    * matrix[1][0] = "a".getBytes("UTF-8");
93    * matrix[1][1] = "matrix".getBytes("UTF-8");
94    * Tensor<String> m = Tensor.create(matrix, String.class);
95    * }</pre>
96    *
97    * @param obj The object to convert to a {@code Tensor<T>}. Note that whether it is compatible
98    *     with the type T is not checked by the type system. For type-safe creation of tensors, use
99    *     {@link Tensors}.
100    * @param type The class object representing the type T.
101    * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
102    *     system.
103    */
104   @SuppressWarnings("unchecked")
create(Object obj, Class<T> type)105   public static <T> Tensor<T> create(Object obj, Class<T> type) {
106     DataType dtype = DataType.fromClass(type);
107     if (!objectCompatWithType(obj, dtype)) {
108       throw new IllegalArgumentException(
109           "DataType of object does not match T (expected "
110               + dtype
111               + ", got "
112               + dataTypeOf(obj)
113               + ")");
114     }
115     return (Tensor<T>) create(obj, dtype);
116   }
117 
118   /**
119    * Creates a tensor from an object whose class is inspected to figure out what the underlying data
120    * type should be.
121    *
122    * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
123    *     system.
124    */
create(Object obj)125   public static Tensor<?> create(Object obj) {
126     return create(obj, dataTypeOf(obj));
127   }
128 
129   /**
130    * Create a Tensor of data type {@code dtype} from a Java object. Requires the parameter {@code T}
131    * to match {@code type}, but this condition is not checked.
132    *
133    * @param obj the object supplying the tensor data.
134    * @param dtype the data type of the tensor to create. It must be compatible with the run-time
135    *     type of the object.
136    * @return the new tensor
137    */
create(Object obj, DataType dtype)138   private static Tensor<?> create(Object obj, DataType dtype) {
139     @SuppressWarnings("rawtypes")
140     Tensor<?> t = new Tensor(dtype);
141     t.shapeCopy = new long[numDimensions(obj, dtype)];
142     fillShape(obj, 0, t.shapeCopy);
143     long nativeHandle;
144     if (t.dtype != DataType.STRING) {
145       int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
146       nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
147       setValue(nativeHandle, obj);
148     } else if (t.shapeCopy.length != 0) {
149       nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
150     } else {
151       nativeHandle = allocateScalarBytes((byte[]) obj);
152     }
153     t.nativeRef = new NativeReference(nativeHandle);
154     return t;
155   }
156 
157   /**
158    * Create a {@link Integer} Tensor with data from the given buffer.
159    *
160    * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
161    * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
162    * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
163    * method.
164    *
165    * @param shape the tensor shape.
166    * @param data a buffer containing the tensor data.
167    * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
168    */
create(long[] shape, IntBuffer data)169   public static Tensor<Integer> create(long[] shape, IntBuffer data) {
170     Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape, data.remaining());
171     t.buffer().asIntBuffer().put(data);
172     return t;
173   }
174 
175   /**
176    * Create a {@link Float} Tensor with data from the given buffer.
177    *
178    * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
179    * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
180    * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
181    * method.
182    *
183    * @param shape the tensor shape.
184    * @param data a buffer containing the tensor data.
185    * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
186    */
create(long[] shape, FloatBuffer data)187   public static Tensor<Float> create(long[] shape, FloatBuffer data) {
188     Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
189     t.buffer().asFloatBuffer().put(data);
190     return t;
191   }
192 
193   /**
194    * Create a {@link Double} Tensor with data from the given buffer.
195    *
196    * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
197    * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
198    * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
199    * method.
200    *
201    * @param shape the tensor shape.
202    * @param data a buffer containing the tensor data.
203    * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
204    */
create(long[] shape, DoubleBuffer data)205   public static Tensor<Double> create(long[] shape, DoubleBuffer data) {
206     Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
207     t.buffer().asDoubleBuffer().put(data);
208     return t;
209   }
210 
211   /**
212    * Create an {@link Long} Tensor with data from the given buffer.
213    *
214    * <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
215    * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
216    * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
217    * method.
218    *
219    * @param shape the tensor shape.
220    * @param data a buffer containing the tensor data.
221    * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
222    */
create(long[] shape, LongBuffer data)223   public static Tensor<Long> create(long[] shape, LongBuffer data) {
224     Tensor<Long> t = allocateForBuffer(DataType.INT64, shape, data.remaining());
225     t.buffer().asLongBuffer().put(data);
226     return t;
227   }
228 
229   /**
230    * Create a Tensor of any type with data from the given buffer.
231    *
232    * <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
233    * encoded into {@code data} as per the specification of the TensorFlow <a
234    * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C
235    * API</a>.
236    *
237    * @param <T> the tensor element type
238    * @param type the tensor element type, represented as a class object.
239    * @param shape the tensor shape.
240    * @param data a buffer containing the tensor data.
241    * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
242    *     buffer
243    */
create(Class<T> type, long[] shape, ByteBuffer data)244   public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) {
245     @SuppressWarnings("unchecked")
246     Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data);
247     return ret;
248   }
249 
create(DataType dtype, long[] shape, ByteBuffer data)250   private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) {
251     int nremaining;
252     if (dtype != DataType.STRING) {
253       int elemBytes = elemByteSize(dtype);
254       if (data.remaining() % elemBytes != 0) {
255         throw new IllegalArgumentException(
256             String.format(
257                 "ByteBuffer with %d bytes is not compatible with a %s Tensor (%d bytes/element)",
258                 data.remaining(), dtype.toString(), elemBytes));
259       }
260       nremaining = data.remaining() / elemBytes;
261     } else {
262       nremaining = data.remaining();
263     }
264     Tensor<?> t = allocateForBuffer(dtype, shape, nremaining);
265     t.buffer().put(data);
266     return t;
267   }
268 
269   /**
270    * Returns this Tensor object with the type {@code Tensor<U>}. This method is useful when given a
271    * value of type {@code Tensor<?>}.
272    *
273    * @param type any (non-null) array of the correct type.
274    * @throws IllegalArgumentException if the actual data type of this object does not match the type
275    *     {@code U}.
276    */
277   @SuppressWarnings("unchecked")
expect(Class<U> type)278   public <U> Tensor<U> expect(Class<U> type) {
279     DataType dt = DataType.fromClass(type);
280     if (!dt.equals(dtype)) {
281       throw new IllegalArgumentException(
282           "Cannot cast from tensor of " + dtype + " to tensor of " + dt);
283     }
284     return ((Tensor<U>) this);
285   }
286 
287   // Helper function to allocate a Tensor for the create() methods that create a Tensor from
288   // a java.nio.Buffer.
289   // Requires: dataType matches T
allocateForBuffer(DataType dataType, long[] shape, int nBuffered)290   private static <T> Tensor<T> allocateForBuffer(DataType dataType, long[] shape, int nBuffered) {
291     final int nflattened = numElements(shape);
292     int nbytes = 0;
293     if (dataType != DataType.STRING) {
294       if (nBuffered != nflattened) {
295         throw incompatibleBuffer(nBuffered, shape);
296       }
297       nbytes = nflattened * elemByteSize(dataType);
298     } else {
299       // DT_STRING tensor encoded in a ByteBuffer.
300       nbytes = nBuffered;
301     }
302     Tensor<T> t = new Tensor<T>(dataType);
303     t.shapeCopy = Arrays.copyOf(shape, shape.length);
304     long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
305     t.nativeRef = new NativeReference(nativeHandle);
306     return t;
307   }
308 
309   /**
310    * Release resources associated with the Tensor.
311    *
312    * <p><b>WARNING:</b>This must be invoked for all tensors that were not been produced by an eager
313    * operation or memory will be leaked.
314    *
315    * <p>The Tensor object is no longer usable after {@code close} returns.
316    */
317   @Override
close()318   public void close() {
319     nativeRef.release();
320   }
321 
322   /** Returns the {@link DataType} of elements stored in the Tensor. */
dataType()323   public DataType dataType() {
324     return dtype;
325   }
326 
327   /**
328    * Returns the number of dimensions (sometimes referred to as <a
329    * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor.
330    *
331    * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc.
332    */
numDimensions()333   public int numDimensions() {
334     return shapeCopy.length;
335   }
336 
337   /** Returns the size, in bytes, of the tensor data. */
numBytes()338   public int numBytes() {
339     return buffer().remaining();
340   }
341 
342   /** Returns the number of elements in a flattened (1-D) view of the tensor. */
numElements()343   public int numElements() {
344     return numElements(shapeCopy);
345   }
346 
347   /**
348    * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
349    * the Tensor, i.e., the sizes of each dimension.
350    *
351    * @return an array where the i-th element is the size of the i-th dimension of the tensor.
352    */
shape()353   public long[] shape() {
354     return shapeCopy;
355   }
356 
357   /**
358    * Returns the value in a scalar {@link Float} tensor.
359    *
360    * @throws IllegalArgumentException if the Tensor does not represent a float scalar.
361    */
floatValue()362   public float floatValue() {
363     return scalarFloat(getNativeHandle());
364   }
365 
366   /**
367    * Returns the value in a scalar {@link Double} tensor.
368    *
369    * @throws IllegalArgumentException if the Tensor does not represent a double scalar.
370    */
doubleValue()371   public double doubleValue() {
372     return scalarDouble(getNativeHandle());
373   }
374 
375   /**
376    * Returns the value in a scalar {@link Integer} tensor.
377    *
378    * @throws IllegalArgumentException if the Tensor does not represent a int scalar.
379    */
intValue()380   public int intValue() {
381     return scalarInt(getNativeHandle());
382   }
383 
384   /**
385    * Returns the value in a scalar {@link Long} tensor.
386    *
387    * @throws IllegalArgumentException if the Tensor does not represent a long scalar.
388    */
longValue()389   public long longValue() {
390     return scalarLong(getNativeHandle());
391   }
392 
393   /**
394    * Returns the value in a scalar {@link Boolean} tensor.
395    *
396    * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
397    */
booleanValue()398   public boolean booleanValue() {
399     return scalarBoolean(getNativeHandle());
400   }
401 
402   /**
403    * Returns the value in a scalar {@link String} tensor.
404    *
405    * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
406    */
bytesValue()407   public byte[] bytesValue() {
408     return scalarBytes(getNativeHandle());
409   }
410 
411   /**
412    * Copies the contents of the tensor to {@code dst} and returns {@code dst}.
413    *
414    * <p>For non-scalar tensors, this method copies the contents of the underlying tensor to a Java
415    * array. For scalar tensors, use one of {@link #bytesValue()}, {@link #floatValue()}, {@link
416    * #doubleValue()}, {@link #intValue()}, {@link #longValue()} or {@link #booleanValue()} instead.
417    * The type and shape of {@code dst} must be compatible with the tensor. For example:
418    *
419    * <pre>{@code
420    * int matrix[2][2] = {{1,2},{3,4}};
421    * try(Tensor t = Tensor.create(matrix)) {
422    *   // Succeeds and prints "3"
423    *   int[][] copy = new int[2][2];
424    *   System.out.println(t.copyTo(copy)[1][0]);
425    *
426    *   // Throws IllegalArgumentException since the shape of dst does not match the shape of t.
427    *   int[][] dst = new int[4][1];
428    *   t.copyTo(dst);
429    * }
430    * }</pre>
431    *
432    * @throws IllegalArgumentException if the tensor is a scalar or if {@code dst} is not compatible
433    *     with the tensor (for example, mismatched data types or shapes).
434    */
copyTo(U dst)435   public <U> U copyTo(U dst) {
436     throwExceptionIfTypeIsIncompatible(dst);
437     readNDArray(getNativeHandle(), dst);
438     return dst;
439   }
440 
441   /**
442    * Write the data of a {@link Integer} tensor into the given buffer.
443    *
444    * <p>Copies {@code numElements()} elements to the buffer.
445    *
446    * @param dst the destination buffer
447    * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
448    *     in this tensor
449    * @throws IllegalArgumentException If the tensor data type is not {@link Integer}
450    */
writeTo(IntBuffer dst)451   public void writeTo(IntBuffer dst) {
452     if (dtype != DataType.INT32) {
453       throw incompatibleBuffer(dst, dtype);
454     }
455     ByteBuffer src = buffer();
456     dst.put(src.asIntBuffer());
457   }
458 
459   /**
460    * Write the data of a {@link Float} tensor into the given buffer.
461    *
462    * <p>Copies {@code numElements()} elements to the buffer.
463    *
464    * @param dst the destination buffer
465    * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
466    *     in this tensor
467    * @throws IllegalArgumentException If the tensor datatype is not {@link Float}
468    */
writeTo(FloatBuffer dst)469   public void writeTo(FloatBuffer dst) {
470     if (dtype != DataType.FLOAT) {
471       throw incompatibleBuffer(dst, dtype);
472     }
473     ByteBuffer src = buffer();
474     dst.put(src.asFloatBuffer());
475   }
476 
477   /**
478    * Write the data of a {@link Double} tensor into the given buffer.
479    *
480    * <p>Copies {@code numElements()} elements to the buffer.
481    *
482    * @param dst the destination buffer
483    * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
484    *     in this tensor
485    * @throws IllegalArgumentException If the tensor datatype is not {@link Double}
486    */
writeTo(DoubleBuffer dst)487   public void writeTo(DoubleBuffer dst) {
488     if (dtype != DataType.DOUBLE) {
489       throw incompatibleBuffer(dst, dtype);
490     }
491     ByteBuffer src = buffer();
492     dst.put(src.asDoubleBuffer());
493   }
494 
495   /**
496    * Write the data of a {@link Long} tensor into the given buffer.
497    *
498    * <p>Copies {@code numElements()} elements to the buffer.
499    *
500    * @param dst the destination buffer
501    * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
502    *     in this tensor
503    * @throws IllegalArgumentException If the tensor datatype is not {@link Long}
504    */
writeTo(LongBuffer dst)505   public void writeTo(LongBuffer dst) {
506     if (dtype != DataType.INT64) {
507       throw incompatibleBuffer(dst, dtype);
508     }
509     ByteBuffer src = buffer();
510     dst.put(src.asLongBuffer());
511   }
512 
513   /**
514    * Write the tensor data into the given buffer.
515    *
516    * <p>Copies {@code numBytes()} bytes to the buffer in native byte order for primitive types.
517    *
518    * @param dst the destination buffer
519    * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
520    *     in this tensor
521    */
writeTo(ByteBuffer dst)522   public void writeTo(ByteBuffer dst) {
523     ByteBuffer src = buffer();
524     dst.put(src);
525   }
526 
527   /** Returns a string describing the type and shape of the Tensor. */
528   @Override
toString()529   public String toString() {
530     return String.format("%s tensor with shape %s", dtype.toString(), Arrays.toString(shape()));
531   }
532 
533   /**
534    * Create a Tensor object from a handle to the C TF_Tensor object.
535    *
536    * <p>Takes ownership of the handle.
537    */
fromHandle(long handle)538   static Tensor<?> fromHandle(long handle) {
539     @SuppressWarnings("rawtypes")
540     Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
541     t.shapeCopy = shape(handle);
542     t.nativeRef = new NativeReference(handle);
543     return t;
544   }
545 
546   /**
547    * Create an eager Tensor object from a handle to the C TF_Tensor object.
548    *
549    * <p>Takes ownership of the handle.
550    */
fromHandle(long handle, EagerSession session)551   static Tensor<?> fromHandle(long handle, EagerSession session) {
552     Tensor<?> t = fromHandle(handle);
553     t.nativeRef.eager(session, t);
554     return t;
555   }
556 
/** Returns the native tensor handle; it becomes 0 once the tensor has been released. */
long getNativeHandle() {
  return this.nativeRef.tensorHandle;
}

// Owner of the native handle; assigned by the factory methods right after construction.
private NativeReference nativeRef = null;
// Element type, fixed at construction.
private final DataType dtype;
// Dimension sizes, filled in by the factory methods right after construction.
private long[] shapeCopy = null;

private Tensor(DataType t) {
  this.dtype = t;
}

/** Returns the ByteBuffer the native layer exposes for this tensor, set to native byte order. */
private ByteBuffer buffer() {
  final ByteBuffer raw = buffer(getNativeHandle());
  return raw.order(ByteOrder.nativeOrder());
}
572 
incompatibleBuffer(Buffer buf, DataType dataType)573   private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
574     return new IllegalArgumentException(
575         String.format("cannot use %s with Tensor of type %s", buf.getClass().getName(), dataType));
576   }
577 
incompatibleBuffer(int numElements, long[] shape)578   private static IllegalArgumentException incompatibleBuffer(int numElements, long[] shape) {
579     return new IllegalArgumentException(
580         String.format(
581             "buffer with %d elements is not compatible with a Tensor with shape %s",
582             numElements, Arrays.toString(shape)));
583   }
584 
numElements(long[] shape)585   private static int numElements(long[] shape) {
586     // assumes a fully-known shape
587     int n = 1;
588     for (int i = 0; i < shape.length; i++) {
589       n *= (int) shape[i];
590     }
591     return n;
592   }
593 
elemByteSize(DataType dataType)594   private static int elemByteSize(DataType dataType) {
595     int size = dataType.byteSize();
596     if (size < 0) {
597         throw new IllegalArgumentException("STRING tensors do not have a fixed element size");
598     }
599     return size;
600   }
601 
throwExceptionIfNotByteOfByteArrays(Object array)602   private static void throwExceptionIfNotByteOfByteArrays(Object array) {
603     if (!array.getClass().getName().equals("[[B")) {
604       throw new IllegalArgumentException(
605           "object cannot be converted to a Tensor as it includes an array with null elements");
606     }
607   }
608 
609   /**
610    * Reference to the underlying native tensor
611    *
612    * <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get
613    * automatically released after executing the last line of the `try` block they were declared in.
614    *
615    * <p>They can also be attached to an eager session, where in this case their lifetime ends either
616    * when this session is closed or when the Tensor instance is no longer referenced and have been
617    * garbage-collected.
618    *
619    * <p>This helper class wraps the tensor native handle and support both situations; If an eager
620    * reference to the tensor exists, it will take care of releasing the tensor at the end of its
621    * life. If the tensor is being explicitly closed before this happens, it will take cake of
622    * clearing its association with any eager session before cleaning up the resources.
623    */
624   private static class NativeReference {
625 
626     /** Attaches this reference to an eager session */
627     private class EagerReference extends EagerSession.NativeReference {
628 
EagerReference(EagerSession session, Tensor<?> tensor)629       EagerReference(EagerSession session, Tensor<?> tensor) {
630         super(session, tensor);
631       }
632 
633       @Override
delete()634       void delete() {
635         // Mark this eager reference as cleared since it has been deleted by the session
636         NativeReference.this.eagerRef = null;
637         NativeReference.this.release();
638       }
639     }
640 
NativeReference(long tensorHandle)641     NativeReference(long tensorHandle) {
642       this.tensorHandle = tensorHandle;
643     }
644 
eager(EagerSession session, Tensor<?> tensor)645     void eager(EagerSession session, Tensor<?> tensor) {
646       if (eagerRef != null) {
647         throw new IllegalStateException("The tensor is already attached to an eager session");
648       }
649       eagerRef = new EagerReference(session, tensor);
650     }
651 
release()652     synchronized void release() {
653       if (tensorHandle != 0L) {
654         // Clear any remaining eager reference to this tensor
655         if (eagerRef != null) {
656           eagerRef.clear();
657           eagerRef = null;
658         }
659         Tensor.delete(tensorHandle);
660         tensorHandle = 0L;
661       }
662     }
663 
664     private long tensorHandle;
665     private EagerReference eagerRef;
666   }
667 
668   private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
669 
670   static {
classDataTypes.put(int.class, DataType.INT32)671     classDataTypes.put(int.class, DataType.INT32);
classDataTypes.put(Integer.class, DataType.INT32)672     classDataTypes.put(Integer.class, DataType.INT32);
classDataTypes.put(long.class, DataType.INT64)673     classDataTypes.put(long.class, DataType.INT64);
classDataTypes.put(Long.class, DataType.INT64)674     classDataTypes.put(Long.class, DataType.INT64);
classDataTypes.put(float.class, DataType.FLOAT)675     classDataTypes.put(float.class, DataType.FLOAT);
classDataTypes.put(Float.class, DataType.FLOAT)676     classDataTypes.put(Float.class, DataType.FLOAT);
classDataTypes.put(double.class, DataType.DOUBLE)677     classDataTypes.put(double.class, DataType.DOUBLE);
classDataTypes.put(Double.class, DataType.DOUBLE)678     classDataTypes.put(Double.class, DataType.DOUBLE);
classDataTypes.put(byte.class, DataType.STRING)679     classDataTypes.put(byte.class, DataType.STRING);
classDataTypes.put(Byte.class, DataType.STRING)680     classDataTypes.put(Byte.class, DataType.STRING);
classDataTypes.put(boolean.class, DataType.BOOL)681     classDataTypes.put(boolean.class, DataType.BOOL);
classDataTypes.put(Boolean.class, DataType.BOOL)682     classDataTypes.put(Boolean.class, DataType.BOOL);
683   }
684 
685   /** The class for the data type to which Java object o corresponds. */
baseObjType(Object o)686   private static Class<?> baseObjType(Object o) {
687     Class<?> c = o.getClass();
688     while (c.isArray()) {
689       c = c.getComponentType();
690     }
691     return c;
692   }
693 
694   /**
695    * The default TensorFlow data type to which Java object o corresponds. Some Java objects
696    * represent more than one TensorFlow data type; for example, 'byte' can represent both {@code
697    * uint8} and {@code string}, with the latter being the default interpretation.
698    */
dataTypeOf(Object o)699   private static DataType dataTypeOf(Object o) {
700     Class<?> c = baseObjType(o);
701     return dataTypeFromClass(c);
702   }
703 
dataTypeFromClass(Class<?> c)704   private static DataType dataTypeFromClass(Class<?> c) {
705     DataType ret = classDataTypes.get(c);
706     if (ret != null) {
707       return ret;
708     }
709     throw new IllegalArgumentException("cannot create Tensors of type " + c.getName());
710   }
711 
712   /**
713    * Return the number of dimensions of the tensor that object {@code o} represents as a tensor
714    * whose datatype is {@code dtype}. Normally this is the same as the number of dimensions of o
715    * itself, but is one smaller for tensors of strings.
716    *
717    * @param o The object to inspect. It must be a valid representation of the given data type.
718    * @param dtype The expected data type of the tensor.
719    */
numDimensions(Object o, DataType dtype)720   private static int numDimensions(Object o, DataType dtype) {
721     int ret = numArrayDimensions(o);
722     if (dtype == DataType.STRING && ret > 0) {
723       return ret - 1;
724     }
725     return ret;
726   }
727 
728   /** Returns the number of dimensions of the array object o. Returns 0 if o is not an array. */
numArrayDimensions(Object o)729   private static int numArrayDimensions(Object o) {
730     Class<?> c = o.getClass();
731     int i = 0;
732     while (c.isArray()) {
733       c = c.getComponentType();
734       i++;
735     }
736     return i;
737   }
738 
739   /**
740    * Fills in the remaining entries in the shape array starting from position {@code dim} with the
741    * dimension sizes of the multidimensional array o. Checks that all arrays reachable from o have
742    * sizes consistent with the filled-in shape, throwing IllegalArgumentException otherwise.
743    */
fillShape(Object o, int dim, long[] shape)744   private static void fillShape(Object o, int dim, long[] shape) {
745     if (shape == null || dim == shape.length) {
746       return;
747     }
748     final int len = Array.getLength(o);
749     if (len == 0) {
750       throw new IllegalArgumentException("cannot create Tensors with a 0 dimension");
751     }
752     if (shape[dim] == 0) {
753       shape[dim] = len;
754     } else if (shape[dim] != len) {
755       throw new IllegalArgumentException(
756           String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
757     }
758     for (int i = 0; i < len; ++i) {
759       fillShape(Array.get(o, i), dim + 1, shape);
760     }
761   }
762 
763   /** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */
objectCompatWithType(Object obj, DataType dtype)764   private static boolean objectCompatWithType(Object obj, DataType dtype) {
765     Class<?> c = baseObjType(obj);
766     DataType dto = dataTypeFromClass(c);
767     int nd = numDimensions(obj, dto);
768     if (!c.isPrimitive() && c != String.class && nd != 0) {
769       throw new IllegalArgumentException(
770           "cannot create non-scalar Tensors from arrays of boxed values");
771     }
772     if (dto.equals(dtype)) {
773       return true;
774     }
775     if (dto == DataType.STRING && dtype == DataType.UINT8) {
776       return true;
777     }
778     return false;
779   }
780 
throwExceptionIfTypeIsIncompatible(Object o)781   private void throwExceptionIfTypeIsIncompatible(Object o) {
782     final int rank = numDimensions();
783     final int oRank = numDimensions(o, dtype);
784     if (oRank != rank) {
785       throw new IllegalArgumentException(
786           String.format(
787               "cannot copy Tensor with %d dimensions into an object with %d", rank, oRank));
788     }
789     if (!objectCompatWithType(o, dtype)) {
790       throw new IllegalArgumentException(
791           String.format(
792               "cannot copy Tensor with DataType %s into an object of type %s",
793               dtype.toString(), o.getClass().getName()));
794     }
795     long[] oShape = new long[rank];
796     fillShape(o, 0, oShape);
797     for (int i = 0; i < oShape.length; ++i) {
798       if (oShape[i] != shape()[i]) {
799         throw new IllegalArgumentException(
800             String.format(
801                 "cannot copy Tensor with shape %s into object with shape %s",
802                 Arrays.toString(shape()), Arrays.toString(oShape)));
803       }
804     }
805   }
806 
allocate(int dtype, long[] shape, long byteSize)807   private static native long allocate(int dtype, long[] shape, long byteSize);
808 
allocateScalarBytes(byte[] value)809   private static native long allocateScalarBytes(byte[] value);
810 
allocateNonScalarBytes(long[] shape, Object[] value)811   private static native long allocateNonScalarBytes(long[] shape, Object[] value);
812 
delete(long handle)813   private static native void delete(long handle);
814 
buffer(long handle)815   private static native ByteBuffer buffer(long handle);
816 
dtype(long handle)817   private static native int dtype(long handle);
818 
shape(long handle)819   private static native long[] shape(long handle);
820 
setValue(long handle, Object value)821   private static native void setValue(long handle, Object value);
822 
scalarFloat(long handle)823   private static native float scalarFloat(long handle);
824 
scalarDouble(long handle)825   private static native double scalarDouble(long handle);
826 
scalarInt(long handle)827   private static native int scalarInt(long handle);
828 
scalarLong(long handle)829   private static native long scalarLong(long handle);
830 
scalarBoolean(long handle)831   private static native boolean scalarBoolean(long handle);
832 
scalarBytes(long handle)833   private static native byte[] scalarBytes(long handle);
834 
readNDArray(long handle, Object value)835   private static native void readNDArray(long handle, Object value);
836 
837   static {
TensorFlow.init()838     TensorFlow.init();
839   }
840 }
841